diff --git a/cns/restserver/api.go b/cns/restserver/api.go index c6a56e9cd9..0378584c17 100644 --- a/cns/restserver/api.go +++ b/cns/restserver/api.go @@ -1124,7 +1124,7 @@ func getAuthTokenAndInterfaceIDFromNcURL(networkContainerURL string) (*cns.Netwo } //nolint:revive // the previous receiver naming "service" is bad, this is correct: -func (h *HTTPRestService) doPublish(ctx context.Context, req cns.PublishNetworkContainerRequest, ncParameters *cns.NetworkContainerParameters) (string, types.ResponseCode) { +func (h *HTTPRestService) doPublish(ctx context.Context, req cns.PublishNetworkContainerRequest, ncParameters *cns.NetworkContainerParameters) (msg string, code types.ResponseCode, status int) { innerReqBytes := req.CreateNetworkContainerRequestBody var innerReq nmagent.PutNetworkContainerRequest @@ -1133,7 +1133,7 @@ func (h *HTTPRestService) doPublish(ctx context.Context, req cns.PublishNetworkC returnMessage := fmt.Sprintf("Failed to unmarshal embedded NC publish request for NC %s, with err: %v", req.NetworkContainerID, err) returnCode := types.NetworkContainerPublishFailed logger.Errorf("[Azure-CNS] %s", returnMessage) - return returnMessage, returnCode + return returnMessage, returnCode, http.StatusInternalServerError } innerReq.AuthenticationToken = ncParameters.AuthToken @@ -1146,10 +1146,14 @@ func (h *HTTPRestService) doPublish(ctx context.Context, req cns.PublishNetworkC returnMessage := fmt.Sprintf("Failed to publish Network Container %s in put Network Container call, with err: %v", req.NetworkContainerID, err) returnCode := types.NetworkContainerPublishFailed logger.Errorf("[Azure-CNS] %s", returnMessage) - return returnMessage, returnCode + var nmaErr nmagent.Error + if errors.As(err, &nmaErr) { + return returnMessage, returnCode, nmaErr.StatusCode() + } + return returnMessage, returnCode, http.StatusInternalServerError } - return "", types.Success + return "", types.Success, http.StatusOK } // Publish Network Container by calling nmagent @@ -1231,7 +1235,7 @@ func (service *HTTPRestService) publishNetworkContainer(w http.ResponseWriter, r if isNetworkJoined { // Publish Network Container - returnMessage, returnCode = service.doPublish(ctx, req, ncParameters) + returnMessage, returnCode, publishStatusCode = service.doPublish(ctx, req, ncParameters) } default: @@ -1342,6 +1346,10 @@ func (service *HTTPRestService) unpublishNetworkContainer(w http.ResponseWriter, if err != nil { returnMessage = fmt.Sprintf("Failed to unpublish Network Container: %s", req.NetworkContainerID) returnCode = types.NetworkContainerUnpublishFailed + var nmaErr nmagent.Error + if errors.As(err, &nmaErr) { + unpublishStatusCode = nmaErr.StatusCode() + } logger.Errorf("[Azure-CNS] %s", returnMessage) } } diff --git a/cns/restserver/api_test.go b/cns/restserver/api_test.go index c3a8a5bd40..7823489316 100644 --- a/cns/restserver/api_test.go +++ b/cns/restserver/api_test.go @@ -847,6 +847,93 @@ func TestPublishNCBadBody(t *testing.T) { } } +func TestPublishNC401(t *testing.T) { + mnma := &fakes.NMAgentClientFake{ + PutNetworkContainerF: func(_ context.Context, _ *nmagent.PutNetworkContainerRequest) error { + return nmagent.Error{ + Code: http.StatusUnauthorized, + Source: "nmagent", + } + }, + JoinNetworkF: func(_ context.Context, _ nmagent.JoinNetworkRequest) error { + return nil + }, + } + + cleanup := setMockNMAgent(svc, mnma) + t.Cleanup(cleanup) + + joinNetworkURL := "http://" + nmagentEndpoint + "/dummyVnetURL" + + createNetworkContainerURL := "http://" + nmagentEndpoint + + "/machine/plugins/?comp=nmagent&type=NetworkManagement/interfaces/dummyIntf/networkContainers/dummyNCURL/authenticationToken/dummyT/api-version/1" + publishNCRequest := &cns.PublishNetworkContainerRequest{ + NetworkID: "foo", + NetworkContainerID: "bar", + JoinNetworkURL: joinNetworkURL, + CreateNetworkContainerURL: createNetworkContainerURL, + CreateNetworkContainerRequestBody: []byte("{\"version\":\"0\"}"), + } + + var body bytes.Buffer + err := json.NewEncoder(&body).Encode(publishNCRequest) + if err != nil { + t.Fatal("error encoding json: err:", err) + } + + //nolint:noctx // also just a test + req, err := http.NewRequest(http.MethodPost, cns.PublishNetworkContainer, &body) + if err != nil { + t.Fatal("error creating new HTTP request: err:", err) + } + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + expStatus := http.StatusOK + gotStatus := w.Code + if expStatus != gotStatus { + t.Error("unexpected http status code: exp:", expStatus, "got:", gotStatus) + } + + var resp cns.PublishNetworkContainerResponse + //nolint:bodyclose // unnnecessary in a test + err = json.NewDecoder(w.Result().Body).Decode(&resp) + if err != nil { + t.Fatal("unexpected error decoding JSON: err:", err) + } + + expCode := types.NetworkContainerPublishFailed + gotCode := resp.Response.ReturnCode + if expCode != gotCode { + t.Error("unexpected return code: exp:", expCode, "got:", gotCode) + } + + expBodyStatus := http.StatusUnauthorized + gotBodyStatus := resp.PublishStatusCode + if expBodyStatus != gotBodyStatus { + t.Error("unexpected publish body status: exp:", expBodyStatus, "got:", gotBodyStatus) + } + + // ensure that the PublishResponseBody is JSON + pubResp := make(map[string]any) + err = json.Unmarshal(resp.PublishResponseBody, &pubResp) + if err != nil { + t.Fatal("unexpected error unmarshaling PublishResponseBody: err:", err) + } + + // ensure that the PublishResponseBody also contains the embedded status from + // NMAgent + expStatusStr := strconv.Itoa(expBodyStatus) + if gotStatusStr, ok := pubResp["httpStatusCode"]; ok { + if gotStatusStr != expStatusStr { + t.Fatalf("expected PublishResponseBody's httpStatusCode to be %q, but was %q\n", expStatusStr, gotStatusStr) + } + } else { + t.Fatal("PublishResponseBody did not contain httpStatusCode") + } +} + func publishNCViaCNS( networkID, networkContainerID, @@ -1013,6 +1100,89 @@ func TestUnpublishNCViaCNS(t *testing.T) { } } +func TestUnpublishNCViaCNS401(t *testing.T) { + mnma := &fakes.NMAgentClientFake{ + DeleteNetworkContainerF: func(_ context.Context, _ nmagent.DeleteContainerRequest) error { + // simulate a 401 from just Delete + return nmagent.Error{ + Code: http.StatusUnauthorized, + Source: "nmagent", + } + }, + JoinNetworkF: func(_ context.Context, _ nmagent.JoinNetworkRequest) error { + return nil + }, + } + + cleanup := setMockNMAgent(svc, mnma) + t.Cleanup(cleanup) + + deleteNetworkContainerURL := "http://" + nmagentEndpoint + + "/machine/plugins/?comp=nmagent&type=NetworkManagement/interfaces/dummyIntf/networkContainers/dummyNCURL/authenticationToken/dummyT/api-version/1/method/DELETE" + + networkContainerID := "ethWebApp" + networkID := "vnet1" + + joinNetworkURL := "http://" + nmagentEndpoint + "/dummyVnetURL" + + unpublishNCRequest := &cns.UnpublishNetworkContainerRequest{ + NetworkID: networkID, + NetworkContainerID: networkContainerID, + JoinNetworkURL: joinNetworkURL, + DeleteNetworkContainerURL: deleteNetworkContainerURL, + } + + var body bytes.Buffer + err := json.NewEncoder(&body).Encode(unpublishNCRequest) + if err != nil { + t.Fatal("error encoding unpublish request: err:", err) + } + + //nolint:noctx // not important in a test + req, err := http.NewRequest(http.MethodPost, cns.UnpublishNetworkContainer, &body) + if err != nil { + t.Fatal("error submitting unpublish request: err:", err) + } + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + var resp cns.UnpublishNetworkContainerResponse + err = decodeResponse(w, &resp) + if err != nil { + t.Fatal("error decoding json: err:", err) + } + + expCode := types.NetworkContainerUnpublishFailed + if gotCode := resp.Response.ReturnCode; gotCode != expCode { + t.Error("unexpected return code: got:", gotCode, "exp:", expCode) + } + + gotStatus := resp.UnpublishStatusCode + expStatus := http.StatusUnauthorized + if gotStatus != expStatus { + t.Error("unexpected http status during unpublish: got:", gotStatus, "exp:", expStatus) + } + + nmaBody := struct { + StatusCode string `json:"httpStatusCode"` + }{} + err = json.Unmarshal(resp.UnpublishResponseBody, &nmaBody) + if err != nil { + t.Fatal("error decoding UnpublishResponseBody as JSON: err:", err) + } + + gotBodyStatus, err := strconv.Atoi(nmaBody.StatusCode) + if err != nil { + t.Fatal("error parsing NMAgent body status code as an integer: err:", err) + } + + expBodyStatus := http.StatusUnauthorized + if gotBodyStatus != expBodyStatus { + t.Errorf("mismatch between expected NMAgent status code (%d) and NMAgent body status code (%d)\n", expBodyStatus, gotBodyStatus) + } +} + func unpublishNCViaCNS(networkID, networkContainerID, deleteNetworkContainerURL string) error { joinNetworkURL := "http://" + nmagentEndpoint + "/dummyVnetURL"