diff --git a/cns/service/main.go b/cns/service/main.go index d80b90b286..81890232aa 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -372,8 +372,12 @@ type NodeInterrogator interface { SupportedAPIs(context.Context) ([]string, error) } +type httpDoer interface { + Do(req *http.Request) (*http.Response, error) +} + // RegisterNode - Tries to register node with DNC when CNS is started in managed DNC mode -func registerNode(httpc *http.Client, httpRestService cns.HTTPService, dncEP, infraVnet, nodeID string, ni NodeInterrogator) error { +func registerNode(ctx context.Context, httpClient httpDoer, httpRestService cns.HTTPService, dncEP, infraVnet, nodeID string, ni NodeInterrogator) error { logger.Printf("[Azure CNS] Registering node %s with Infrastructure Network: %s PrivateEndpoint: %s", nodeID, infraVnet, dncEP) var ( @@ -386,12 +390,11 @@ func registerNode(httpc *http.Client, httpRestService cns.HTTPService, dncEP, in supportedApis, retErr := ni.SupportedAPIs(context.TODO()) if retErr != nil { - logger.Errorf("[Azure CNS] Failed to retrieve SupportedApis from NMagent of node %s with Infrastructure Network: %s PrivateEndpoint: %s", - nodeID, infraVnet, dncEP) - return retErr + return errors.Wrap(retErr, fmt.Sprintf("[Azure CNS] Failed to retrieve SupportedApis from NMagent of node %s with Infrastructure Network: %s PrivateEndpoint: %s", + nodeID, infraVnet, dncEP)) } - // To avoid any null-pointer deferencing errors. + // To avoid any null-pointer de-referencing errors. if supportedApis == nil { supportedApis = []string{} } @@ -400,7 +403,7 @@ func registerNode(httpc *http.Client, httpRestService cns.HTTPService, dncEP, in // CNS tries to register Node for maximum of an hour. err := retry.Do(func() error { - return sendRegisterNodeRequest(httpc, httpRestService, nodeRegisterRequest, url) + return errors.Wrap(sendRegisterNodeRequest(ctx, httpClient, httpRestService, nodeRegisterRequest, url), "failed to sendRegisterNodeRequest") }, retry.Delay(acn.FiveSeconds), retry.Attempts(maxRetryNodeRegister), retry.DelayType(retry.FixedDelay)) return errors.Wrap(err, fmt.Sprintf("[Azure CNS] Failed to register node %s after maximum reties for an hour with Infrastructure Network: %s PrivateEndpoint: %s", @@ -408,22 +411,28 @@ func registerNode(httpc *http.Client, httpRestService cns.HTTPService, dncEP, in } // sendRegisterNodeRequest func helps in registering the node until there is an error. -func sendRegisterNodeRequest(httpc *http.Client, httpRestService cns.HTTPService, nodeRegisterRequest cns.NodeRegisterRequest, registerURL string) error { +func sendRegisterNodeRequest(ctx context.Context, httpClient httpDoer, httpRestService cns.HTTPService, nodeRegisterRequest cns.NodeRegisterRequest, registerURL string) error { var body bytes.Buffer err := json.NewEncoder(&body).Encode(nodeRegisterRequest) if err != nil { - log.Errorf("[Azure CNS] Failed to register node while encoding json failed with non-retriable err %v", err) + log.Errorf("[Azure CNS] Failed to register node while encoding json failed with non-retryable err %v", err) return errors.Wrap(retry.Unrecoverable(err), "failed to sendRegisterNodeRequest") } - response, err := httpc.Post(registerURL, "application/json", &body) + request, err := http.NewRequestWithContext(ctx, http.MethodPost, registerURL, &body) if err != nil { - logger.Errorf("[Azure CNS] Failed to register node with retriable err: %+v", err) - return errors.Wrap(err, "failed to sendRegisterNodeRequest") + return errors.Wrap(err, "failed to build request") + } + + request.Header.Set("Content-Type", "application/json") + response, err := httpClient.Do(request) + if err != nil { + return errors.Wrap(err, "http request failed") } + defer response.Body.Close() - if response.StatusCode != http.StatusCreated { + if response.StatusCode != http.StatusOK { err = fmt.Errorf("[Azure CNS] Failed to register node, DNC replied with http status code %s", strconv.Itoa(response.StatusCode)) logger.Errorf(err.Error()) return errors.Wrap(err, "failed to sendRegisterNodeRequest") @@ -432,7 +441,7 @@ func sendRegisterNodeRequest(httpc *http.Client, httpRestService cns.HTTPService var req cns.SetOrchestratorTypeRequest err = json.NewDecoder(response.Body).Decode(&req) if err != nil { - log.Errorf("[Azure CNS] decoding Node Resgister response json failed with err %v", err) + log.Errorf("[Azure CNS] decoding Node Register response json failed with err %v", err) return errors.Wrap(err, "failed to sendRegisterNodeRequest") } httpRestService.SetNodeOrchestrator(&req) @@ -791,7 +800,7 @@ func main() { } // We might be configured to reinitialize state from the CNI instead of the apiserver. - // If so, we should check that the the CNI is new enough to support the state commands, + // If so, we should check that the CNI is new enough to support the state commands, // otherwise we fall back to the existing behavior. if cnsconfig.InitializeFromCNI { var isGoodVer bool @@ -891,9 +900,12 @@ func main() { httpRestService.SetOption(acn.OptInfrastructureNetworkID, infravnet) httpRestService.SetOption(acn.OptNodeID, nodeID) - registerErr := registerNode(acn.GetHttpClient(), httpRestService, privateEndpoint, infravnet, nodeID, nmaClient) + // Passing in the default http client that already implements Do function + standardClient := http.DefaultClient + + registerErr := registerNode(rootCtx, standardClient, httpRestService, privateEndpoint, infravnet, nodeID, nmaClient) if registerErr != nil { - logger.Errorf("[Azure CNS] Resgistering Node failed with error: %v PrivateEndpoint: %s InfrastructureNetworkID: %s NodeID: %s", + logger.Errorf("[Azure CNS] Registering Node failed with error: %v PrivateEndpoint: %s InfrastructureNetworkID: %s NodeID: %s", registerErr, privateEndpoint, infravnet, diff --git a/cns/service/main_test.go b/cns/service/main_test.go new file mode 100644 index 0000000000..42a64df761 --- /dev/null +++ b/cns/service/main_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/Azure/azure-container-networking/cns" + "github.com/Azure/azure-container-networking/cns/fakes" + "github.com/Azure/azure-container-networking/cns/logger" + "github.com/stretchr/testify/assert" +) + +// MockHTTPClient is a mock implementation of HTTPClient +type MockHTTPClient struct { + Response *http.Response + Err error +} + +// Post is the implementation of the Post method for MockHTTPClient +func (m *MockHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return m.Response, m.Err +} + +func TestSendRegisterNodeRequest_StatusOK(t *testing.T) { + ctx := context.Background() + logger.InitLogger("testlogs", 0, 0, "./") + httpServiceFake := fakes.NewHTTPServiceFake() + nodeRegisterReq := cns.NodeRegisterRequest{ + NumCores: 2, + NmAgentSupportedApis: nil, + } + + url := "https://localhost:9000/api" + + // Create a mock HTTP client + mockResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{"status": "success", "OrchestratorType": "Kubernetes", "DncPartitionKey": "1234", "NodeID": "5678"}`)), + Header: make(http.Header), + } + + mockClient := &MockHTTPClient{Response: mockResponse, Err: nil} + + assert.NoError(t, sendRegisterNodeRequest(ctx, mockClient, httpServiceFake, nodeRegisterReq, url)) +} + +func TestSendRegisterNodeRequest_StatusAccepted(t *testing.T) { + ctx := context.Background() + logger.InitLogger("testlogs", 0, 0, "./") + httpServiceFake := fakes.NewHTTPServiceFake() + nodeRegisterReq := cns.NodeRegisterRequest{ + NumCores: 2, + NmAgentSupportedApis: nil, + } + + url := "https://localhost:9000/api" + + // Create a mock HTTP client + mockResponse := &http.Response{ + StatusCode: http.StatusAccepted, + Body: io.NopCloser(bytes.NewBufferString(`{"status": "accepted", "OrchestratorType": "Kubernetes", "DncPartitionKey": "1234", "NodeID": "5678"}`)), + Header: make(http.Header), + } + + mockClient := &MockHTTPClient{Response: mockResponse, Err: nil} + + assert.Error(t, sendRegisterNodeRequest(ctx, mockClient, httpServiceFake, nodeRegisterReq, url)) +}