diff --git a/cns/restserver/api.go b/cns/restserver/api.go index b44c98e36b..ddb042dd13 100644 --- a/cns/restserver/api.go +++ b/cns/restserver/api.go @@ -746,7 +746,7 @@ func (service *HTTPRestService) setOrchestratorType(w http.ResponseWriter, r *ht service.dncPartitionKey = req.DncPartitionKey nodeID = service.state.NodeID - if nodeID == "" || nodeID == req.NodeID { + if nodeID == "" || nodeID == req.NodeID || !service.areNCsPresent() { switch req.OrchestratorType { case cns.ServiceFabric: fallthrough diff --git a/cns/restserver/api_test.go b/cns/restserver/api_test.go index 844cd0d115..749b0b9e73 100644 --- a/cns/restserver/api_test.go +++ b/cns/restserver/api_test.go @@ -26,6 +26,7 @@ import ( acncommon "github.com/Azure/azure-container-networking/common" "github.com/Azure/azure-container-networking/processlock" "github.com/Azure/azure-container-networking/store" + "github.com/stretchr/testify/assert" ) const ( @@ -195,6 +196,102 @@ func TestSetOrchestratorType(t *testing.T) { } } +func FirstByte(b []byte, err error) []byte { + if err != nil { + panic(err) + } + return b +} + +func FirstRequest(req *http.Request, err error) *http.Request { + if err != nil { + panic(err) + } + return req +} + +func TestSetOrchestratorType_NCsPresent(t *testing.T) { + tests := []struct { + name string + service *HTTPRestService + writer *httptest.ResponseRecorder + request *http.Request + response cns.Response + wanthttperror bool + }{ + { + name: "Node already set, and has NCs, so registration should fail", + service: &HTTPRestService{ + state: &httpRestServiceState{ + NodeID: "node1", + ContainerStatus: map[string]containerstatus{ + "nc1": {}, + }, + ContainerIDByOrchestratorContext: map[string]string{ + "nc1": "present", + }, + }, + }, + writer: httptest.NewRecorder(), + request: FirstRequest(http.NewRequestWithContext( + context.TODO(), http.MethodPost, cns.SetOrchestratorType, bytes.NewReader( + FirstByte(json.Marshal( //nolint:errchkjson //inline map, only using returned bytes + cns.SetOrchestratorTypeRequest{ + OrchestratorType: "Kubernetes", + DncPartitionKey: "partition1", + NodeID: "node2", + }))))), + response: cns.Response{ + ReturnCode: types.InvalidRequest, + Message: "Invalid request since this node has already been registered as node1", + }, + wanthttperror: false, + }, + { + name: "Node already set, with no NCs, so registration should succeed", + service: &HTTPRestService{ + state: &httpRestServiceState{ + NodeID: "node1", + }, + }, + writer: httptest.NewRecorder(), + request: FirstRequest(http.NewRequestWithContext( + context.TODO(), http.MethodPost, cns.SetOrchestratorType, bytes.NewReader( + FirstByte(json.Marshal( //nolint:errchkjson //inline map, only using returned bytes + cns.SetOrchestratorTypeRequest{ + OrchestratorType: "Kubernetes", + DncPartitionKey: "partition1", + NodeID: "node2", + }))))), + response: cns.Response{ + ReturnCode: types.Success, + Message: "", + }, + wanthttperror: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + var resp cns.Response + // Since this is global, we have to replace the state + oldstate := svc.state + svc.state = tt.service.state + mux.ServeHTTP(tt.writer, tt.request) + // Replace back old state + svc.state = oldstate + + err := decodeResponse(tt.writer, &resp) + if tt.wanthttperror { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.response, resp) + }) + } +} + func TestCreateNetworkContainer(t *testing.T) { // requires more than 30 seconds to run fmt.Println("Test: TestCreateNetworkContainer") diff --git a/cns/restserver/util.go b/cns/restserver/util.go index f64999fcfe..7711272f3b 100644 --- a/cns/restserver/util.go +++ b/cns/restserver/util.go @@ -600,6 +600,14 @@ func (service *HTTPRestService) getNetworkContainerDetails(networkContainerID st return containerDetails, containerExists } +// areNCsPresent returns true if NCs are present in CNS, false if no NCs are present +func (service *HTTPRestService) areNCsPresent() bool { + if len(service.state.ContainerStatus) == 0 && len(service.state.ContainerIDByOrchestratorContext) == 0 { + return false + } + return true +} + // Check if the network is joined func (service *HTTPRestService) isNetworkJoined(networkID string) bool { namedLock.LockAcquire(stateJoinedNetworks) diff --git a/cns/restserver/util_test.go b/cns/restserver/util_test.go new file mode 100644 index 0000000000..9415fe2c57 --- /dev/null +++ b/cns/restserver/util_test.go @@ -0,0 +1,52 @@ +package restserver + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAreNCsPresent(t *testing.T) { + tests := []struct { + name string + service HTTPRestService + want bool + }{ + { + name: "container status present", + service: HTTPRestService{ + state: &httpRestServiceState{ + ContainerStatus: map[string]containerstatus{ + "nc1": {}, + }, + }, + }, + want: true, + }, + { + name: "containerIDByOrchestorContext present", + service: HTTPRestService{ + state: &httpRestServiceState{ + ContainerIDByOrchestratorContext: map[string]string{ + "nc1": "present", + }, + }, + }, + want: true, + }, + { + name: "neither containerStatus nor containerIDByOrchestratorContext present", + service: HTTPRestService{ + state: &httpRestServiceState{}, + }, + want: false, + }, + } + for _, tt := range tests { //nolint:govet // this mutex copy is to keep a local reference to this variable in the test func closure, and is ok + tt := tt //nolint:govet // this mutex copy is to keep a local reference to this variable in the test func closure, and is ok + t.Run(tt.name, func(t *testing.T) { + got := tt.service.areNCsPresent() + assert.Equal(t, got, tt.want) + }) + } +}