Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,13 @@ func main() {
HTTPClient: &http.Client{},
}

httpRestService, err := restserver.NewHTTPRestService(&config, &wireserver.Client{HTTPClient: &http.Client{}}, &wsProxy, nmaClient,
wsclient := &wireserver.Client{
HostPort: cnsconfig.WireserverIP,
HTTPClient: &http.Client{},
Logger: logger.Log,
}

httpRestService, err := restserver.NewHTTPRestService(&config, wsclient, &wsProxy, nmaClient,
endpointStateStore, conflistGenerator, homeAzMonitor)
if err != nil {
logger.Errorf("Failed to create CNS object, err:%v.\n", err)
Expand Down
32 changes: 27 additions & 5 deletions cns/wireserver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"encoding/xml"
"io"
"net/http"
"net/url"

"github.com/Azure/azure-container-networking/cns/logger"
"github.com/pkg/errors"
)

const hostQueryURL = "http://168.63.129.16/machine/plugins?comp=nmagent&type=getinterfaceinfov1"
const (
WireserverIP = "168.63.129.16"
)

type GetNetworkContainerOpts struct {
NetworkContainerID string
Expand All @@ -25,14 +27,34 @@ type do interface {
}

type Client struct {
HostPort string

HTTPClient do
Logger interface {
Printf(string, ...any)
}
}

func (c *Client) hostport() string {
return c.HostPort
}

// GetInterfaces queries interfaces from the wireserver.
func (c *Client) GetInterfaces(ctx context.Context) (*GetInterfacesResult, error) {
logger.Printf("[Azure CNS] GetPrimaryInterfaceInfoFromHost")
c.Logger.Printf("[Azure CNS] GetPrimaryInterfaceInfoFromHost")

q := &url.Values{}
q.Add("comp", "nmagent")
q.Add("type", "getinterfaceinfov1")

reqURL := &url.URL{
Scheme: "http",
Host: c.hostport(),
Path: "/machine/plugins",
RawQuery: q.Encode(),
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, hostQueryURL, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody)
if err != nil {
return nil, errors.Wrap(err, "failed to construct request")
}
Expand All @@ -46,7 +68,7 @@ func (c *Client) GetInterfaces(ctx context.Context) (*GetInterfacesResult, error
return nil, errors.Wrap(err, "failed to read response body")
}

logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %s", string(b))
c.Logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %s", string(b))

var res GetInterfacesResult
if err := xml.NewDecoder(bytes.NewReader(b)).Decode(&res); err != nil {
Expand Down
86 changes: 86 additions & 0 deletions cns/wireserver/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package wireserver_test

import (
"context"
"encoding/xml"
"net/http"
"net/http/httptest"
"testing"

"github.com/Azure/azure-container-networking/cns/wireserver"
)

var _ http.RoundTripper = &TestTripper{}

// TestTripper is a mock implementation of a round tripper that allows clients
// to substitute their own implementation, so that HTTP requests can be
// asserted against and stub responses can be generated.
type TestTripper struct {
RoundTripF func(*http.Request) (*http.Response, error)
}

func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripF(req)
}

type NOPLogger struct{}

func (m *NOPLogger) Printf(_ string, _ ...any) {}

func TestGetInterfaces(t *testing.T) {
tests := []struct {
name string
hostport string
expURL string
}{
{
"real ws url",
"168.63.129.16",
"http://168.63.129.16/machine/plugins?comp=nmagent&type=getinterfaceinfov1",
},
{
"local ws url",
"127.0.0.1:9001",
"http://127.0.0.1:9001/machine/plugins?comp=nmagent&type=getinterfaceinfov1",
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
// create a wireserver client using a test tripper so that it can be asserted
// that the correct requests are sent.
var gotURL string
client := &wireserver.Client{
HostPort: test.hostport,
Logger: &NOPLogger{},
HTTPClient: &http.Client{
Transport: &TestTripper{
RoundTripF: func(req *http.Request) (*http.Response, error) {
gotURL = req.URL.String()
rr := httptest.NewRecorder()
resp := wireserver.GetInterfacesResult{}
err := xml.NewEncoder(rr).Encode(&resp)
if err != nil {
t.Fatal("unexpected error encoding mock wireserver response: err:", err)
}

return rr.Result(), nil
},
},
},
}

// invoke the endpoint on Wireserver
_, err := client.GetInterfaces(context.TODO())
if err != nil {
t.Fatal("unexpected error invoking GetInterfaces: err:", err)
}

if test.expURL != gotURL {
t.Error("received request URL to wireserve does not match expectation:\n\texp:", test.expURL, "\n\tgot:", gotURL)
}
})
}
}