diff --git a/cns/service/main.go b/cns/service/main.go index 0314092eb7..3304541540 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -677,7 +677,13 @@ func main() { HTTPClient: &http.Client{}, } - httpRestService, err := restserver.NewHTTPRestService(&config, &wireserver.Client{HTTPClient: &http.Client{}}, &wsProxy, nmaClient, + wsclient := &wireserver.Client{ + HostPort: cnsconfig.WireserverIP, + HTTPClient: &http.Client{}, + Logger: logger.Log, + } + + httpRestService, err := restserver.NewHTTPRestService(&config, wsclient, &wsProxy, nmaClient, endpointStateStore, conflistGenerator, homeAzMonitor) if err != nil { logger.Errorf("Failed to create CNS object, err:%v.\n", err) diff --git a/cns/wireserver/client.go b/cns/wireserver/client.go index 43e3c566aa..417e60ef6f 100644 --- a/cns/wireserver/client.go +++ b/cns/wireserver/client.go @@ -6,12 +6,14 @@ import ( "encoding/xml" "io" "net/http" + "net/url" - "github.com/Azure/azure-container-networking/cns/logger" "github.com/pkg/errors" ) -const hostQueryURL = "http://168.63.129.16/machine/plugins?comp=nmagent&type=getinterfaceinfov1" +const ( + WireserverIP = "168.63.129.16" +) type GetNetworkContainerOpts struct { NetworkContainerID string @@ -25,14 +27,34 @@ type do interface { } type Client struct { + HostPort string + HTTPClient do + Logger interface { + Printf(string, ...any) + } +} + +func (c *Client) hostport() string { + return c.HostPort } // GetInterfaces queries interfaces from the wireserver. func (c *Client) GetInterfaces(ctx context.Context) (*GetInterfacesResult, error) { - logger.Printf("[Azure CNS] GetPrimaryInterfaceInfoFromHost") + c.Logger.Printf("[Azure CNS] GetPrimaryInterfaceInfoFromHost") + + q := &url.Values{} + q.Add("comp", "nmagent") + q.Add("type", "getinterfaceinfov1") + + reqURL := &url.URL{ + Scheme: "http", + Host: c.hostport(), + Path: "/machine/plugins", + RawQuery: q.Encode(), + } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, hostQueryURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody) if err != nil { return nil, errors.Wrap(err, "failed to construct request") } @@ -46,7 +68,7 @@ func (c *Client) GetInterfaces(ctx context.Context) (*GetInterfacesResult, error return nil, errors.Wrap(err, "failed to read response body") } - logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %s", string(b)) + c.Logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %s", string(b)) var res GetInterfacesResult if err := xml.NewDecoder(bytes.NewReader(b)).Decode(&res); err != nil { diff --git a/cns/wireserver/client_test.go b/cns/wireserver/client_test.go new file mode 100644 index 0000000000..55ed954eff --- /dev/null +++ b/cns/wireserver/client_test.go @@ -0,0 +1,86 @@ +package wireserver_test + +import ( + "context" + "encoding/xml" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Azure/azure-container-networking/cns/wireserver" +) + +var _ http.RoundTripper = &TestTripper{} + +// TestTripper is a mock implementation of a round tripper that allows clients +// to substitute their own implementation, so that HTTP requests can be +// asserted against and stub responses can be generated. +type TestTripper struct { + RoundTripF func(*http.Request) (*http.Response, error) +} + +func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripF(req) +} + +type NOPLogger struct{} + +func (m *NOPLogger) Printf(_ string, _ ...any) {} + +func TestGetInterfaces(t *testing.T) { + tests := []struct { + name string + hostport string + expURL string + }{ + { + "real ws url", + "168.63.129.16", + "http://168.63.129.16/machine/plugins?comp=nmagent&type=getinterfaceinfov1", + }, + { + "local ws url", + "127.0.0.1:9001", + "http://127.0.0.1:9001/machine/plugins?comp=nmagent&type=getinterfaceinfov1", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + // create a wireserver client using a test tripper so that it can be asserted + // that the correct requests are sent. + var gotURL string + client := &wireserver.Client{ + HostPort: test.hostport, + Logger: &NOPLogger{}, + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + gotURL = req.URL.String() + rr := httptest.NewRecorder() + resp := wireserver.GetInterfacesResult{} + err := xml.NewEncoder(rr).Encode(&resp) + if err != nil { + t.Fatal("unexpected error encoding mock wireserver response: err:", err) + } + + return rr.Result(), nil + }, + }, + }, + } + + // invoke the endpoint on Wireserver + _, err := client.GetInterfaces(context.TODO()) + if err != nil { + t.Fatal("unexpected error invoking GetInterfaces: err:", err) + } + + if test.expURL != gotURL { + t.Error("received request URL to wireserve does not match expectation:\n\texp:", test.expURL, "\n\tgot:", gotURL) + } + }) + } +}