diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 9fa17aaf80..b5794d8dd4 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -18,5 +18,6 @@ "UseHTTPS" : false, "TLSSubjectName" : "", "TLSCertificatePath" : "", - "TLSEndpoint" : "localhost:10091" + "TLSEndpoint" : "localhost:10091", + "WireserverIP": "168.63.129.16" } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 8ef4c60c3e..bd2f81b098 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -24,6 +24,7 @@ type CNSConfig struct { TLSSubjectName string TLSCertificatePath string TLSEndpoint string + WireserverIP string } type TelemetrySettings struct { diff --git a/cns/restserver/api.go b/cns/restserver/api.go index 5e9a9d79ff..3a46e7549a 100644 --- a/cns/restserver/api.go +++ b/cns/restserver/api.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net" "net/http" + "regexp" "runtime" "strings" @@ -1139,6 +1140,16 @@ func getAuthTokenFromCreateNetworkContainerURL( return strings.Split(strings.Split(createNetworkContainerURL, "authenticationToken/")[1], "/")[0] } +var rgx = regexp.MustCompile("^http[s]?://(.*?)/joinedVirtualNetworks.*?$") + +func extractHostFromJoinNetworkURL(url string) string { + submatches := rgx.FindStringSubmatch(url) + if len(submatches) != 2 { + return "" + } + return submatches[1] +} + // Publish Network Container by calling nmagent func (service *HTTPRestService) publishNetworkContainer(w http.ResponseWriter, r *http.Request) { logger.Printf("[Azure-CNS] PublishNetworkContainer") @@ -1199,8 +1210,15 @@ func (service *HTTPRestService) publishNetworkContainer(w http.ResponseWriter, r // Store ncGetVersionURL needed for calling NMAgent to check if vfp programming is completed for the NC primaryInterfaceIdentifier := getInterfaceIdFromCreateNetworkContainerURL(req.CreateNetworkContainerURL) authToken := getAuthTokenFromCreateNetworkContainerURL(req.CreateNetworkContainerURL) + + // we attempt to extract the wireserver IP to use from the request, otherwise default to the well-known IP. + hostIP := extractHostFromJoinNetworkURL(req.JoinNetworkURL) + if hostIP == "" { + hostIP = nmagentclient.WireserverIP + } + ncGetVersionURL := fmt.Sprintf(nmagentclient.GetNetworkContainerVersionURLFmt, - nmagentclient.WireserverIP, + hostIP, primaryInterfaceIdentifier, req.NetworkContainerID, authToken) diff --git a/cns/restserver/api_test.go b/cns/restserver/api_test.go index 29de9cfc6f..07551c71bc 100644 --- a/cns/restserver/api_test.go +++ b/cns/restserver/api_test.go @@ -553,6 +553,16 @@ func publishNCViaCNS(t *testing.T, fmt.Printf("PublishNetworkContainer succeded with response %+v, raw:%+v\n", resp, w.Body) } +func TestExtractHost(t *testing.T) { + joinURL := "http://127.0.0.1:9001/joinedVirtualNetworks/c9b8e695-2de1-11eb-bf54-000d3af666c8/api-version/1" + + host := extractHostFromJoinNetworkURL(joinURL) + expected := "127.0.0.1:9001" + if host != expected { + t.Fatalf("expected host %q, got %q", expected, host) + } +} + func TestUnpublishNCViaCNS(t *testing.T) { fmt.Println("Test: unpublishNetworkContainer") diff --git a/cns/service/main.go b/cns/service/main.go index 0262232be9..7eb274652b 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -404,6 +404,10 @@ func main() { configuration.SetCNSConfigDefaults(&cnsconfig) logger.Printf("[Azure CNS] Read config :%+v", cnsconfig) + if cnsconfig.WireserverIP != "" { + nmagentclient.WireserverIP = cnsconfig.WireserverIP + } + if cnsconfig.ChannelMode == cns.Managed { config.ChannelMode = cns.Managed privateEndpoint = cnsconfig.ManagedSettings.PrivateEndpoint