diff --git a/Makefile b/Makefile index 2dce8cc90b..bbba1ba616 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,7 @@ CNSFILES = \ $(wildcard cns/imdsclient/*.go) \ $(wildcard cns/ipamclient/*.go) \ $(wildcard cns/hnsclient/*.go) \ + $(wildcard cns/nmagentclient/*.go) \ $(wildcard cns/restserver/*.go) \ $(wildcard cns/routes/*.go) \ $(wildcard cns/service/*.go) \ diff --git a/cns/NetworkContainerContract.go b/cns/NetworkContainerContract.go index 55d82e7118..a328f95a82 100644 --- a/cns/NetworkContainerContract.go +++ b/cns/NetworkContainerContract.go @@ -1,6 +1,8 @@ package cns -import "encoding/json" +import ( + "encoding/json" +) // Container Network Service DNC Contract const ( @@ -8,6 +10,8 @@ const ( CreateOrUpdateNetworkContainer = "/network/createorupdatenetworkcontainer" DeleteNetworkContainer = "/network/deletenetworkcontainer" GetNetworkContainerStatus = "/network/getnetworkcontainerstatus" + PublishNetworkContainer = "/network/publishnetworkcontainer" + UnpublishNetworkContainer = "/network/unpublishnetworkcontainer" GetInterfaceForContainer = "/network/getinterfaceforcontainer" GetNetworkContainerByOrchestratorContext = "/network/getnetworkcontainerbyorchestratorcontext" AttachContainerToNetwork = "/network/attachcontainertonetwork" @@ -182,3 +186,36 @@ type NetworkInterface struct { Name string IPAddress string } + +// PublishNetworkContainerRequest specifies request to publish network container via NMAgent. +type PublishNetworkContainerRequest struct { + NetworkID string + NetworkContainerID string + JoinNetworkURL string + CreateNetworkContainerURL string + CreateNetworkContainerRequestBody []byte +} + +// PublishNetworkContainerResponse specifies the response to publish network container request. +type PublishNetworkContainerResponse struct { + Response Response + PublishErrorStr string + PublishStatusCode int + PublishResponseBody []byte +} + +// UnpublishNetworkContainerRequest specifies request to unpublish network container via NMAgent. +type UnpublishNetworkContainerRequest struct { + NetworkID string + NetworkContainerID string + JoinNetworkURL string + DeleteNetworkContainerURL string +} + +// UnpublishNetworkContainerResponse specifies the response to unpublish network container request. +type UnpublishNetworkContainerResponse struct { + Response Response + UnpublishErrorStr string + UnpublishStatusCode int + UnpublishResponseBody []byte +} diff --git a/cns/nmagentclient/nmagentclient.go b/cns/nmagentclient/nmagentclient.go new file mode 100644 index 0000000000..9dcd7aa6ff --- /dev/null +++ b/cns/nmagentclient/nmagentclient.go @@ -0,0 +1,64 @@ +package nmagentclient + +import ( + "bytes" + "encoding/json" + "net/http" + + "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/log" +) + +// JoinNetwork joins the given network +func JoinNetwork( + networkID string, + joinNetworkURL string) (*http.Response, error) { + log.Printf("[NMAgentClient] JoinNetwork: %s", networkID) + + // Empty body is required as wireserver cannot handle a post without the body. + var body bytes.Buffer + json.NewEncoder(&body).Encode("") + response, err := common.GetHttpClient().Post(joinNetworkURL, "application/json", &body) + + if err == nil && response.StatusCode == http.StatusOK { + defer response.Body.Close() + } + + log.Printf("[NMAgentClient][Response] Join network: %s. Response: %+v. Error: %v", + networkID, response, err) + + return response, err +} + +// PublishNetworkContainer publishes given network container +func PublishNetworkContainer( + networkContainerID string, + createNetworkContainerURL string, + requestBodyData []byte) (*http.Response, error) { + log.Printf("[NMAgentClient] PublishNetworkContainer NC: %s", networkContainerID) + + requestBody := bytes.NewBuffer(requestBodyData) + response, err := common.GetHttpClient().Post(createNetworkContainerURL, "application/json", requestBody) + + log.Printf("[NMAgentClient][Response] Publish NC: %s. Response: %+v. Error: %v", + networkContainerID, response, err) + + return response, err +} + +// UnpublishNetworkContainer unpublishes given network container +func UnpublishNetworkContainer( + networkContainerID string, + deleteNetworkContainerURL string) (*http.Response, error) { + log.Printf("[NMAgentClient] UnpublishNetworkContainer NC: %s", networkContainerID) + + // Empty body is required as wireserver cannot handle a post without the body. + var body bytes.Buffer + json.NewEncoder(&body).Encode("") + response, err := common.GetHttpClient().Post(deleteNetworkContainerURL, "application/json", &body) + + log.Printf("[NMAgentClient][Response] Unpublish NC: %s. Response: %+v. Error: %v", + networkContainerID, response, err) + + return response, err +} diff --git a/cns/restserver/api.go b/cns/restserver/api.go index f2a4fcbf2d..586dd1c2bc 100644 --- a/cns/restserver/api.go +++ b/cns/restserver/api.go @@ -24,6 +24,9 @@ const ( UnsupportedVerb = 21 UnsupportedNetworkContainerType = 22 InvalidRequest = 23 + NetworkJoinFailed = 24 + NetworkContainerPublishFailed = 25 + NetworkContainerUnpublishFailed = 26 UnexpectedError = 99 ) diff --git a/cns/restserver/restserver.go b/cns/restserver/restserver.go index f38c625d62..578eed6780 100644 --- a/cns/restserver/restserver.go +++ b/cns/restserver/restserver.go @@ -6,6 +6,7 @@ package restserver import ( "encoding/json" "fmt" + "io/ioutil" "net" "net/http" "runtime" @@ -19,6 +20,7 @@ import ( "github.com/Azure/azure-container-networking/cns/imdsclient" "github.com/Azure/azure-container-networking/cns/ipamclient" "github.com/Azure/azure-container-networking/cns/networkcontainers" + "github.com/Azure/azure-container-networking/cns/nmagentclient" "github.com/Azure/azure-container-networking/cns/routes" acn "github.com/Azure/azure-container-networking/common" "github.com/Azure/azure-container-networking/log" @@ -26,12 +28,19 @@ import ( "github.com/Azure/azure-container-networking/store" ) +var ( + // Named Lock for accessing different states in httpRestServiceState + namedLock = acn.InitNamedLock() +) + const ( // Key against which CNS state is persisted. storeKey = "ContainerNetworkService" swiftAPIVersion = "1" attach = "Attach" detach = "Detach" + // Rest service state identifier for named lock + stateJoinedNetworks = "JoinedNetworks" ) // HTTPRestService represents http listener for CNS - Container Networking Service. @@ -66,6 +75,7 @@ type httpRestServiceState struct { ContainerStatus map[string]containerstatus // NetworkContainerID is key. Networks map[string]*networkInfo TimeStamp time.Time + joinedNetworks map[string]struct{} } type networkInfo struct { @@ -161,6 +171,8 @@ func (service *HTTPRestService) Start(config *common.ServiceConfig) error { listener.AddHandler(cns.NumberOfCPUCoresPath, service.getNumberOfCPUCores) listener.AddHandler(cns.CreateHostNCApipaEndpointPath, service.createHostNCApipaEndpoint) listener.AddHandler(cns.DeleteHostNCApipaEndpointPath, service.deleteHostNCApipaEndpoint) + listener.AddHandler(cns.PublishNetworkContainer, service.publishNetworkContainer) + listener.AddHandler(cns.UnpublishNetworkContainer, service.unpublishNetworkContainer) // handlers for v0.2 listener.AddHandler(cns.V2Prefix+cns.SetEnvironmentPath, service.setEnvironment) @@ -185,6 +197,11 @@ func (service *HTTPRestService) Start(config *common.ServiceConfig) error { listener.AddHandler(cns.V2Prefix+cns.CreateHostNCApipaEndpointPath, service.createHostNCApipaEndpoint) listener.AddHandler(cns.V2Prefix+cns.DeleteHostNCApipaEndpointPath, service.deleteHostNCApipaEndpoint) + // Initialize HTTP client to be reused in CNS + connectionTimeout, _ := service.GetOption(acn.OptHttpConnectionTimeout).(int) + responseHeaderTimeout, _ := service.GetOption(acn.OptHttpResponseHeaderTimeout).(int) + acn.InitHttpClient(connectionTimeout, responseHeaderTimeout) + log.Printf("[Azure CNS] Listening.") return nil } @@ -1723,3 +1740,218 @@ func (service *HTTPRestService) deleteHostNCApipaEndpoint(w http.ResponseWriter, err = service.Listener.Encode(w, &response) log.Response(service.Name, response, response.Response.ReturnCode, ReturnCodeToString(response.Response.ReturnCode), err) } + +// Check if the network is joined +func (service *HTTPRestService) isNetworkJoined(networkID string) bool { + namedLock.LockAcquire(stateJoinedNetworks) + defer namedLock.LockRelease(stateJoinedNetworks) + + if service.state.joinedNetworks == nil { + service.state.joinedNetworks = make(map[string]struct{}) + } + + _, exists := service.state.joinedNetworks[networkID] + + return exists +} + +// Set the network as joined +func (service *HTTPRestService) setNetworkStateJoined(networkID string) { + namedLock.LockAcquire(stateJoinedNetworks) + defer namedLock.LockRelease(stateJoinedNetworks) + + service.state.joinedNetworks[networkID] = struct{}{} +} + +// Join Network by calling nmagent +func (service *HTTPRestService) joinNetwork( + networkID string, + joinNetworkURL string) (*http.Response, error, error) { + var err error + joinResponse, joinErr := nmagentclient.JoinNetwork( + networkID, + joinNetworkURL) + + if joinErr == nil && joinResponse.StatusCode == http.StatusOK { + // Network joined successfully + service.setNetworkStateJoined(networkID) + log.Printf("[Azure-CNS] setNetworkStateJoined for network: %s", networkID) + } else { + err = fmt.Errorf("Failed to join network: %s", networkID) + } + + return joinResponse, joinErr, err +} + +// Publish Network Container by calling nmagent +func (service *HTTPRestService) publishNetworkContainer(w http.ResponseWriter, r *http.Request) { + log.Printf("[Azure-CNS] PublishNetworkContainer") + + var ( + err error + req cns.PublishNetworkContainerRequest + returnCode int + returnMessage string + publishResponse *http.Response + publishStatusCode int + publishResponseBody []byte + publishError error + publishErrorStr string + isNetworkJoined bool + ) + + err = service.Listener.Decode(w, r, &req) + log.Request(service.Name, &req, err) + if err != nil { + return + } + + switch r.Method { + case "POST": + // Join Network if not joined already + isNetworkJoined = service.isNetworkJoined(req.NetworkID) + if !isNetworkJoined { + publishResponse, publishError, err = service.joinNetwork(req.NetworkID, req.JoinNetworkURL) + if err == nil { + isNetworkJoined = true + } else { + returnMessage = err.Error() + returnCode = NetworkJoinFailed + } + } + + if isNetworkJoined { + // Publish Network Container + publishResponse, publishError = nmagentclient.PublishNetworkContainer( + req.NetworkContainerID, + req.CreateNetworkContainerURL, + req.CreateNetworkContainerRequestBody) + if publishError != nil || publishResponse.StatusCode != http.StatusOK { + returnMessage = fmt.Sprintf("Failed to publish Network Container: %s", req.NetworkContainerID) + returnCode = NetworkContainerPublishFailed + log.Errorf("[Azure-CNS] %s", returnMessage) + } + } + default: + returnMessage = "PublishNetworkContainer API expects a POST" + returnCode = UnsupportedVerb + } + + if publishError != nil { + publishErrorStr = publishError.Error() + } + + if publishResponse != nil { + publishStatusCode = publishResponse.StatusCode + + var errParse error + publishResponseBody, errParse = ioutil.ReadAll(publishResponse.Body) + if errParse != nil { + returnMessage = fmt.Sprintf("Failed to parse the publish body. Error: %v", errParse) + returnCode = UnexpectedError + log.Errorf("[Azure-CNS] %s", returnMessage) + } + + publishResponse.Body.Close() + } + + response := cns.PublishNetworkContainerResponse{ + Response: cns.Response{ + ReturnCode: returnCode, + Message: returnMessage, + }, + PublishErrorStr: publishErrorStr, + PublishStatusCode: publishStatusCode, + PublishResponseBody: publishResponseBody, + } + + err = service.Listener.Encode(w, &response) + log.Response(service.Name, response, response.Response.ReturnCode, ReturnCodeToString(response.Response.ReturnCode), err) +} + +// Unpublish Network Container by calling nmagent +func (service *HTTPRestService) unpublishNetworkContainer(w http.ResponseWriter, r *http.Request) { + log.Printf("[Azure-CNS] UnpublishNetworkContainer") + + var ( + err error + req cns.UnpublishNetworkContainerRequest + returnCode int + returnMessage string + unpublishResponse *http.Response + unpublishStatusCode int + unpublishResponseBody []byte + unpublishError error + unpublishErrorStr string + isNetworkJoined bool + ) + + err = service.Listener.Decode(w, r, &req) + log.Request(service.Name, &req, err) + if err != nil { + return + } + + switch r.Method { + case "POST": + // Join Network if not joined already + isNetworkJoined = service.isNetworkJoined(req.NetworkID) + if !isNetworkJoined { + unpublishResponse, unpublishError, err = service.joinNetwork(req.NetworkID, req.JoinNetworkURL) + if err == nil { + isNetworkJoined = true + } else { + returnMessage = err.Error() + returnCode = NetworkJoinFailed + } + } + + if isNetworkJoined { + // Unpublish Network Container + unpublishResponse, unpublishError = nmagentclient.UnpublishNetworkContainer( + req.NetworkContainerID, + req.DeleteNetworkContainerURL) + if unpublishError != nil || unpublishResponse.StatusCode != http.StatusOK { + returnMessage = fmt.Sprintf("Failed to unpublish Network Container: %s", req.NetworkContainerID) + returnCode = NetworkContainerUnpublishFailed + log.Errorf("[Azure-CNS] %s", returnMessage) + } + + if unpublishResponse != nil { + var errParse error + unpublishResponseBody, errParse = ioutil.ReadAll(unpublishResponse.Body) + if errParse != nil { + returnMessage = fmt.Sprintf("Failed to parse the unpublish body. Error: %v", errParse) + returnCode = UnexpectedError + log.Errorf("[Azure-CNS] %s", returnMessage) + } + + unpublishResponse.Body.Close() + } + } + default: + returnMessage = "UnpublishNetworkContainer API expects a POST" + returnCode = UnsupportedVerb + } + + if unpublishError != nil { + unpublishErrorStr = unpublishError.Error() + } + + if unpublishResponse != nil { + unpublishStatusCode = unpublishResponse.StatusCode + } + + response := cns.UnpublishNetworkContainerResponse{ + Response: cns.Response{ + ReturnCode: returnCode, + Message: returnMessage, + }, + UnpublishErrorStr: unpublishErrorStr, + UnpublishStatusCode: unpublishStatusCode, + UnpublishResponseBody: unpublishResponseBody, + } + + err = service.Listener.Encode(w, &response) + log.Response(service.Name, response, response.Response.ReturnCode, ReturnCodeToString(response.Response.ReturnCode), err) +} diff --git a/cns/restserver/restserver_test.go b/cns/restserver/restserver_test.go index 98806a21d3..3f10724869 100644 --- a/cns/restserver/restserver_test.go +++ b/cns/restserver/restserver_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "net/url" "os" + "strings" "testing" "github.com/Azure/azure-container-networking/cns" @@ -66,16 +67,23 @@ var ( } ) +const ( + nmagentEndpoint = "localhost:9000" +) + func getInterfaceInfo(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/xml") output, _ := xml.Marshal(hostQueryResponse) w.Write(output) } -func getContainerInfo(w http.ResponseWriter, r *http.Request) { +func nmagentHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.WriteHeader(http.StatusOK) - w.Write([]byte(hostQueryForProgrammedVersionResponse)) + + if strings.Contains(r.RequestURI, "networkContainers") { + w.Write([]byte(hostQueryForProgrammedVersionResponse)) + } } // Wraps the test run with service setup and teardown. @@ -109,7 +117,7 @@ func TestMain(m *testing.M) { mux = service.(*HTTPRestService).Listener.GetMux() // Setup mock nmagent server - u, err := url.Parse("tcp://localhost:9000") + u, err := url.Parse("tcp://" + nmagentEndpoint) if err != nil { fmt.Println(err.Error()) } @@ -120,7 +128,7 @@ func TestMain(m *testing.M) { } nmAgentServer.AddHandler("/getInterface", getInterfaceInfo) - nmAgentServer.AddHandler("machine/plugins/?comp=nmagent&type=NetworkManagement/interfaces/{interface}/networkContainers/{networkContainer}/authenticationToken/{authToken}/api-version/{version}", getContainerInfo) + nmAgentServer.AddHandler("/", nmagentHandler) err = nmAgentServer.Start(make(chan error, 1)) if err != nil { @@ -133,6 +141,7 @@ func TestMain(m *testing.M) { // Cleanup. service.Stop() + nmAgentServer.Stop() os.Exit(exitCode) } @@ -749,3 +758,80 @@ func TestGetNumOfCPUCores(t *testing.T) { fmt.Printf("getNumberOfCPUCores Responded with %+v\n", numOfCoresResponse) } } + +func TestPublishNCViaCNS(t *testing.T) { + fmt.Println("Test: publishNetworkContainer") + + var ( + body bytes.Buffer + resp cns.PublishNetworkContainerResponse + ) + + networkID := "vnet1" + networkContainerID := "ethWebApp" + joinNetworkURL := "http://" + nmagentEndpoint + "/dummyVnetURL" + createNetworkContainerURL := "http://" + nmagentEndpoint + "/networkContainers/dummyNCURL" + + publishNCRequest := &cns.PublishNetworkContainerRequest{ + NetworkID: networkID, + NetworkContainerID: networkContainerID, + JoinNetworkURL: joinNetworkURL, + CreateNetworkContainerURL: createNetworkContainerURL, + CreateNetworkContainerRequestBody: make([]byte, 0), + } + + json.NewEncoder(&body).Encode(publishNCRequest) + req, err := http.NewRequest(http.MethodPost, cns.PublishNetworkContainer, &body) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + err = decodeResponse(w, &resp) + if err != nil || resp.Response.ReturnCode != 0 { + t.Errorf("PublishNetworkContainer failed with response %+v Err:%+v", resp, err) + t.Fatal(err) + } + + fmt.Printf("PublishNetworkContainer succeded with response %+v, raw:%+v\n", resp, w.Body) +} + +func TestUnpublishNCViaCNS(t *testing.T) { + fmt.Println("Test: unpublishNetworkContainer") + + var ( + body bytes.Buffer + resp cns.UnpublishNetworkContainerResponse + ) + + networkID := "vnet1" + networkContainerID := "ethWebApp" + joinNetworkURL := "http://" + nmagentEndpoint + "/dummyVnetURL" + deleteNetworkContainerURL := "http://" + nmagentEndpoint + "/networkContainers/dummyNCURL" + + unpublishNCRequest := &cns.UnpublishNetworkContainerRequest{ + NetworkID: networkID, + NetworkContainerID: networkContainerID, + JoinNetworkURL: joinNetworkURL, + DeleteNetworkContainerURL: deleteNetworkContainerURL, + } + + json.NewEncoder(&body).Encode(unpublishNCRequest) + req, err := http.NewRequest(http.MethodPost, cns.UnpublishNetworkContainer, &body) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + err = decodeResponse(w, &resp) + if err != nil || resp.Response.ReturnCode != 0 { + t.Errorf("UnpublishNetworkContainer failed with response %+v Err:%+v", resp, err) + t.Fatal(err) + } + + fmt.Printf("UnpublishNetworkContainer succeded with response %+v, raw:%+v\n", resp, w.Body) +} diff --git a/cns/service/main.go b/cns/service/main.go index a7b2af83e6..3609aba856 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -152,6 +152,20 @@ var args = acn.ArgumentList{ Type: "bool", DefaultValue: true, }, + { + Name: acn.OptHttpConnectionTimeout, + Shorthand: acn.OptHttpConnectionTimeoutAlias, + Description: "Set HTTP connection timeout in seconds to be used by http client in CNS", + Type: "int", + DefaultValue: "5", + }, + { + Name: acn.OptHttpResponseHeaderTimeout, + Shorthand: acn.OptHttpResponseHeaderTimeoutAlias, + Description: "Set HTTP response header timeout in seconds to be used by http client in CNS", + Type: "int", + DefaultValue: "120", + }, } // Prints description and version information. @@ -179,6 +193,8 @@ func main() { vers := acn.GetArg(acn.OptVersion).(bool) createDefaultExtNetworkType := acn.GetArg(acn.OptCreateDefaultExtNetworkType).(string) telemetryEnabled := acn.GetArg(acn.OptTelemetry).(bool) + httpConnectionTimeout := acn.GetArg(acn.OptHttpConnectionTimeout).(int) + httpResponseHeaderTimeout := acn.GetArg(acn.OptHttpResponseHeaderTimeout).(int) if vers { printVersion() @@ -240,6 +256,8 @@ func main() { httpRestService.SetOption(acn.OptNetPluginPath, cniPath) httpRestService.SetOption(acn.OptNetPluginConfigFile, cniConfigFile) httpRestService.SetOption(acn.OptCreateDefaultExtNetworkType, createDefaultExtNetworkType) + httpRestService.SetOption(acn.OptHttpConnectionTimeout, httpConnectionTimeout) + httpRestService.SetOption(acn.OptHttpResponseHeaderTimeout, httpResponseHeaderTimeout) // Create default ext network if commandline option is set if len(strings.TrimSpace(createDefaultExtNetworkType)) > 0 { diff --git a/common/config.go b/common/config.go index c1c4b23488..7a272dc73e 100644 --- a/common/config.go +++ b/common/config.go @@ -79,4 +79,12 @@ const ( // Disable Telemetry OptTelemetry = "telemetry" OptTelemetryAlias = "dt" + + // HTTP connection timeout + OptHttpConnectionTimeout = "http-connection-timeout" + OptHttpConnectionTimeoutAlias = "httpcontimeout" + + // HTTP response header timeout + OptHttpResponseHeaderTimeout = "http-response-header-timeout" + OptHttpResponseHeaderTimeoutAlias = "httprespheadertimeout" ) diff --git a/common/utils.go b/common/utils.go index 7c5b7b7b6b..8b08a37ba3 100644 --- a/common/utils.go +++ b/common/utils.go @@ -71,6 +71,36 @@ type metadataWrapper struct { Metadata Metadata `json:"compute"` } +var ( + // Creating http client object to be reused instead of creating one every time. + // This helps make use of the cached tcp connections. + // Clients are safe for concurrent use by multiple goroutines. + httpClient *http.Client +) + +// InitHttpClient initializes the httpClient object +func InitHttpClient( + connectionTimeoutSec int, + responseHeaderTimeoutSec int) *http.Client { + log.Printf("[Utils] Initializing HTTP client with connection timeout: %d, response header timeout: %d", + connectionTimeoutSec, responseHeaderTimeoutSec) + httpClient = &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: time.Duration(connectionTimeoutSec) * time.Second, + }).DialContext, + ResponseHeaderTimeout: time.Duration(responseHeaderTimeoutSec) * time.Second, + }, + } + + return httpClient +} + +// GetHttpClient returns the singleton httpClient object +func GetHttpClient() *http.Client { + return httpClient +} + // LogNetworkInterfaces logs the host's network interfaces in the default namespace. func LogNetworkInterfaces() { interfaces, err := net.Interfaces()