From 3714fbc5a0acb597ee54f0055b415c1a8050fe0d Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 16 Mar 2022 17:37:05 -0400 Subject: [PATCH 01/44] Add implementation for GetNetworkConfiguration Previously the NMAgent client did not have support for the GetNetworkConfiguration API call. This adds it and appropriate coverage. --- nmagent/client.go | 163 +++++++++++++++++++++++ nmagent/client_test.go | 288 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 451 insertions(+) create mode 100644 nmagent/client.go create mode 100644 nmagent/client_test.go diff --git a/nmagent/client.go b/nmagent/client.go new file mode 100644 index 0000000000..e717716668 --- /dev/null +++ b/nmagent/client.go @@ -0,0 +1,163 @@ +package nmagent + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + "github.com/google/uuid" +) + +const ( + JoinNetworkPath string = "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + GetNetworkConfigPath string = "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/%s/api-version/1" +) + +// Error is a aberrent condition encountered when interacting with the NMAgent +// API +type Error struct { + Code int // the HTTP status code received +} + +func (e Error) Error() string { + return fmt.Sprintf("nmagent: http status %d", e.Code) +} + +type VirtualNetwork struct { + CNetSpace string `json:"cnetSpace"` + DefaultGateway string `json:"defaultGateway"` + DNSServers []string `json:"dnsServers"` + Subnets []Subnet `json:"subnets"` + VNetSpace string `json:"vnetSpace"` + VNetVersion string `json:"vnetVersion"` +} + +type Subnet struct { + AddressPrefix string `json:"addressPrefix"` + SubnetName string `json:"subnetName"` + Tags []Tag `json:"tags"` +} + +type Tag struct { + Name string `json:"name"` + Type string `json:"type"` // the type of the tag (e.g. "System" or "Custom") +} + +// Client is an agent for exchanging information with NMAgent +type Client struct { + HTTPClient *http.Client + + // config + Host string + Port string +} + +// JoinNetwork joins a node to a customer's virtual network +func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { + // we need to be a little defensive, because there is no bad request response + // from NMAgent + if _, err := uuid.Parse(networkID); err != nil { + return fmt.Errorf("bad network ID %q: %w", networkID, err) + } + + joinURL := &url.URL{ + Scheme: "https", + Host: net.JoinHostPort(c.Host, c.Port), + Path: fmt.Sprintf(JoinNetworkPath, networkID), + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL.String(), strings.NewReader("")) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + // TODO(timraymond): exponential backoff needed + for { + // check to see if the context is still alive + if err := ctx.Err(); err != nil { + return err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("executing request: %w", err) + } + defer resp.Body.Close() + + // the response from NMAgent only contains the HTTP status code, so there is + // no need to parse it + + switch resp.StatusCode { + case http.StatusOK: + return nil + case http.StatusInternalServerError: + return Error{ + Code: http.StatusInternalServerError, + } + case http.StatusProcessing: + continue + default: + return nil + } + } +} + +// GetNetworkConfiguration retrieves the configuration of a customer's virtual +// network. Only subnets which have been delegated will be returned +func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (VirtualNetwork, error) { + path := &url.URL{ + Scheme: "https", + Host: net.JoinHostPort(c.Host, c.Port), + Path: fmt.Sprintf(GetNetworkConfigPath, vnetID), + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path.String(), strings.NewReader("")) + if err != nil { + return VirtualNetwork{}, fmt.Errorf("creating http request to %q: %w", path.String(), err) + } + + for { + // check to see if the context is dead + if err := ctx.Err(); err != nil { + return VirtualNetwork{}, err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return VirtualNetwork{}, fmt.Errorf("executing http request to %q: %w", path.String(), err) + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + var out VirtualNetwork + err = json.NewDecoder(resp.Body).Decode(&out) + if err != nil { + return VirtualNetwork{}, fmt.Errorf("decoding json response for %q: %w", path.String(), err) + } + return out, nil + case http.StatusProcessing: + continue + default: + return VirtualNetwork{}, fmt.Errorf("unexpected HTTP status from NMAgent (%d): %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + } +} + +/* +func (c *Client) PutNetworkContainer(ctx context.Context) error { + return nil +} + +func (c *Client) DeleteNetworkContainer(ctx context.Context) error { + return nil +} + +func (c *Client) GetNmAgentSupportedApiURLFmt(ctx context.Context) error { + return nil +} +*/ diff --git a/nmagent/client_test.go b/nmagent/client_test.go new file mode 100644 index 0000000000..a4ed80e65d --- /dev/null +++ b/nmagent/client_test.go @@ -0,0 +1,288 @@ +package nmagent_test + +import ( + "context" + "dnc/nmagent" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/go-cmp/cmp" +) + +var _ http.RoundTripper = &TestTripper{} + +// TestTripper is a RoundTripper with a customizeable RoundTrip method for +// testing purposes +type TestTripper struct { + RoundTripF func(*http.Request) (*http.Response, error) +} + +func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripF(req) +} + +func TestNMAgentClientJoinNetwork(t *testing.T) { + joinNetTests := []struct { + name string + id string + exp string + respStatus int + shouldErr bool + }{ + { + "happy path", + "00000000-0000-0000-0000-000000000000", + "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + http.StatusOK, + false, + }, + { + "empty network ID", + "", + "", + http.StatusOK, // this shouldn't be checked + true, + }, + { + "malformed UUID", + "00000000-0000", + "", + http.StatusOK, // this shouldn't be checked + true, + }, + { + "internal error", + "00000000-0000-0000-0000-000000000000", + "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + http.StatusInternalServerError, + true, + }, + } + + for _, test := range joinNetTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // create a client + var got string + client := nmagent.Client{ + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + rr.WriteHeader(test.respStatus) + return rr.Result(), nil + }, + }, + }, + } + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + // attempt to join network + // TODO(timraymond): need a more realistic network ID, I think + err := client.JoinNetwork(ctx, test.id) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if got != test.exp { + t.Error("received URL differs from expectation: got", got, "exp:", test.exp) + } + }) + } +} + +func TestNMAgentClientJoinNetworkRetry(t *testing.T) { + // we want to ensure that the client will automatically follow up with + // NMAgent, so we want to track the number of requests that it makes + invocations := 0 + exp := 10 + + client := nmagent.Client{ + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if invocations < exp { + rr.WriteHeader(http.StatusProcessing) + invocations++ + } else { + rr.WriteHeader(http.StatusOK) + } + return rr.Result(), nil + }, + }, + }, + } + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + // attempt to join network + err := client.JoinNetwork(ctx, "00000000-0000-0000-0000-000000000000") + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if invocations != exp { + t.Error("client did not make the expected number of API calls: got:", invocations, "exp:", exp) + } +} + +func TestNMAgentGetNetworkConfig(t *testing.T) { + getTests := []struct { + name string + vnetID string + expURL string + expVNet nmagent.VirtualNetwork + shouldCall bool + shouldErr bool + }{ + { + "happy path", + "00000000-0000-0000-0000-000000000000", + "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + nmagent.VirtualNetwork{ + CNetSpace: "10.10.1.0/24", + DefaultGateway: "10.10.0.1", + DNSServers: []string{ + "1.1.1.1", + "1.0.0.1", + }, + Subnets: []nmagent.Subnet{}, + VNetSpace: "10.0.0.0/8", + VNetVersion: "2018", // TODO(timraymond): what's a real version look like? + }, + true, + false, + }, + } + + for _, test := range getTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var got string + client := &nmagent.Client{ + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + got = req.URL.Path + rr.WriteHeader(http.StatusOK) + err := json.NewEncoder(rr).Encode(&test.expVNet) + if err != nil { + return nil, fmt.Errorf("encoding response: %w", err) + } + + return rr.Result(), nil + }, + }, + }, + } + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + gotVNet, err := client.GetNetworkConfiguration(ctx, test.vnetID) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if got != test.expURL && test.shouldCall { + t.Error("unexpected URL: got:", got, "exp:", test.expURL) + } + + if !cmp.Equal(gotVNet, test.expVNet) { + t.Error("received vnet differs from expected: diff:", cmp.Diff(gotVNet, test.expVNet)) + } + }) + } +} + +func TestNMAgentGetNetworkConfigRetry(t *testing.T) { + t.Parallel() + + count := 0 + exp := 10 + client := &nmagent.Client{ + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if count < exp { + rr.WriteHeader(http.StatusProcessing) + count++ + } else { + rr.WriteHeader(http.StatusOK) + } + + // we still need a fake response + var out nmagent.VirtualNetwork + err := json.NewEncoder(rr).Encode(&out) + if err != nil { + return nil, err + } + + return rr.Result(), nil + }, + }, + }, + } + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + _, err := client.GetNetworkConfiguration(ctx, "00000000-0000-0000-0000-000000000000") + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if count != exp { + t.Error("unexpected number of API calls: exp:", exp, "got:", count) + } +} From e87cbe655c6ba9d7ce7e84949cbacb005754cd02 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 16 Mar 2022 18:29:41 -0400 Subject: [PATCH 02/44] Refactor retry loops to use shared function The cancellable retry was common enough that it made sense to extract it to a separate BackoffRetry function in internal. This made its functionality easier to test and reduced the number of tests necessary for each new endpoint --- nmagent/client.go | 76 +++++++++++++---------------- nmagent/internal/internal.go | 34 +++++++++++++ nmagent/internal/internal_test.go | 79 +++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 42 deletions(-) create mode 100644 nmagent/internal/internal.go create mode 100644 nmagent/internal/internal_test.go diff --git a/nmagent/client.go b/nmagent/client.go index e717716668..fba1b0a766 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -2,6 +2,7 @@ package nmagent import ( "context" + "dnc/nmagent/internal" "encoding/json" "fmt" "net" @@ -27,6 +28,12 @@ func (e Error) Error() string { return fmt.Sprintf("nmagent: http status %d", e.Code) } +// Temporary reports whether the error encountered from NMAgent should be +// considered temporary, and thus retriable +func (e Error) Temporary() bool { + return e.Code == http.StatusProcessing +} + type VirtualNetwork struct { CNetSpace string `json:"cnetSpace"` DefaultGateway string `json:"defaultGateway"` @@ -75,35 +82,20 @@ func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { return fmt.Errorf("creating request: %w", err) } - // TODO(timraymond): exponential backoff needed - for { - // check to see if the context is still alive - if err := ctx.Err(); err != nil { - return err - } - + err = internal.BackoffRetry(ctx, func() error { resp, err := c.HTTPClient.Do(req) if err != nil { return fmt.Errorf("executing request: %w", err) } defer resp.Body.Close() - // the response from NMAgent only contains the HTTP status code, so there is - // no need to parse it - - switch resp.StatusCode { - case http.StatusOK: - return nil - case http.StatusInternalServerError: - return Error{ - Code: http.StatusInternalServerError, - } - case http.StatusProcessing: - continue - default: - return nil + if resp.StatusCode != http.StatusOK { + return Error{resp.StatusCode} } - } + return nil + }) + + return err } // GetNetworkConfiguration retrieves the configuration of a customer's virtual @@ -115,37 +107,37 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (Vi Path: fmt.Sprintf(GetNetworkConfigPath, vnetID), } + var out VirtualNetwork + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path.String(), strings.NewReader("")) if err != nil { - return VirtualNetwork{}, fmt.Errorf("creating http request to %q: %w", path.String(), err) + return out, fmt.Errorf("creating http request to %q: %w", path.String(), err) } - for { - // check to see if the context is dead - if err := ctx.Err(); err != nil { - return VirtualNetwork{}, err - } - + err = internal.BackoffRetry(ctx, func() error { resp, err := c.HTTPClient.Do(req) if err != nil { - return VirtualNetwork{}, fmt.Errorf("executing http request to %q: %w", path.String(), err) + return fmt.Errorf("executing http request to %q: %w", path.String(), err) } defer resp.Body.Close() - switch resp.StatusCode { - case http.StatusOK: - var out VirtualNetwork - err = json.NewDecoder(resp.Body).Decode(&out) - if err != nil { - return VirtualNetwork{}, fmt.Errorf("decoding json response for %q: %w", path.String(), err) - } - return out, nil - case http.StatusProcessing: - continue - default: - return VirtualNetwork{}, fmt.Errorf("unexpected HTTP status from NMAgent (%d): %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + if resp.StatusCode != http.StatusOK { + return Error{resp.StatusCode} + } + + err = json.NewDecoder(resp.Body).Decode(&out) + if err != nil { + return fmt.Errorf("decoding json response for %q: %w", path.String(), err) } + + return nil + }) + + if err != nil { + // no need to wrap, as the retry wrapper is intended to be transparent + return out, err } + return out, nil } /* diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go new file mode 100644 index 0000000000..be813672e2 --- /dev/null +++ b/nmagent/internal/internal.go @@ -0,0 +1,34 @@ +package internal + +import ( + "context" + "errors" +) + +type TemporaryError interface { + error + Temporary() bool +} + +// BackoffRetry implements cancellable exponential backoff of some arbitrary +// function +func BackoffRetry(ctx context.Context, run func() error) error { + for { + if err := ctx.Err(); err != nil { + return err + } + + err := run() + if err != nil { + // check to see if it's temporary + var tempErr TemporaryError + if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { + continue + } + + // since it's not temporary, it can't be retried, so... + return err + } + return nil + } +} diff --git a/nmagent/internal/internal_test.go b/nmagent/internal/internal_test.go new file mode 100644 index 0000000000..b456ecaf43 --- /dev/null +++ b/nmagent/internal/internal_test.go @@ -0,0 +1,79 @@ +package internal + +import ( + "context" + "errors" + "testing" +) + +type TestError struct{} + +func (t TestError) Error() string { + return "oh no!" +} + +func (t TestError) Temporary() bool { + return true +} + +func TestBackoffRetry(t *testing.T) { + got := 0 + exp := 10 + + ctx := context.Background() + + err := BackoffRetry(ctx, func() error { + if got < exp { + got++ + return TestError{} + } + return nil + }) + + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if got < exp { + t.Error("unexpected number of invocations: got:", got, "exp:", exp) + } +} + +func TestBackoffRetryWithCancel(t *testing.T) { + got := 0 + exp := 5 + total := 10 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := BackoffRetry(ctx, func() error { + got++ + if got >= exp { + cancel() + } + + if got < total { + return TestError{} + } + return nil + }) + + if err == nil { + t.Error("expected context cancellation error, but received none") + } + + if got != exp { + t.Error("unexpected number of iterations: exp:", exp, "got:", got) + } +} + +func TestBackoffRetryUnretriableError(t *testing.T) { + err := BackoffRetry(context.Background(), func() error { + return errors.New("boom") + }) + + if err == nil { + t.Fatal("expected an error, but none was returned") + } +} From 641388c2e1ddf0e9b7763ef203ae8d068ee5004d Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 16 Mar 2022 18:46:16 -0400 Subject: [PATCH 03/44] Slight re-org The client had enough extra stuff in it that it made sense to start separating things into different files --- nmagent/client.go | 36 ------------------------------------ nmagent/nmagent.go | 25 +++++++++++++++++++++++++ nmagent/responses.go | 21 +++++++++++++++++++++ 3 files changed, 46 insertions(+), 36 deletions(-) create mode 100644 nmagent/nmagent.go create mode 100644 nmagent/responses.go diff --git a/nmagent/client.go b/nmagent/client.go index fba1b0a766..840ef27c97 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -18,42 +18,6 @@ const ( GetNetworkConfigPath string = "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/%s/api-version/1" ) -// Error is a aberrent condition encountered when interacting with the NMAgent -// API -type Error struct { - Code int // the HTTP status code received -} - -func (e Error) Error() string { - return fmt.Sprintf("nmagent: http status %d", e.Code) -} - -// Temporary reports whether the error encountered from NMAgent should be -// considered temporary, and thus retriable -func (e Error) Temporary() bool { - return e.Code == http.StatusProcessing -} - -type VirtualNetwork struct { - CNetSpace string `json:"cnetSpace"` - DefaultGateway string `json:"defaultGateway"` - DNSServers []string `json:"dnsServers"` - Subnets []Subnet `json:"subnets"` - VNetSpace string `json:"vnetSpace"` - VNetVersion string `json:"vnetVersion"` -} - -type Subnet struct { - AddressPrefix string `json:"addressPrefix"` - SubnetName string `json:"subnetName"` - Tags []Tag `json:"tags"` -} - -type Tag struct { - Name string `json:"name"` - Type string `json:"type"` // the type of the tag (e.g. "System" or "Custom") -} - // Client is an agent for exchanging information with NMAgent type Client struct { HTTPClient *http.Client diff --git a/nmagent/nmagent.go b/nmagent/nmagent.go new file mode 100644 index 0000000000..f98183a84b --- /dev/null +++ b/nmagent/nmagent.go @@ -0,0 +1,25 @@ +package nmagent + +import ( + "fmt" + "net/http" +) + +// Error is a aberrent condition encountered when interacting with the NMAgent +// API +type Error struct { + Code int // the HTTP status code received +} + +func (e Error) Error() string { + return fmt.Sprintf("nmagent: http status %d", e.Code) +} + +// Temporary reports whether the error encountered from NMAgent should be +// considered temporary, and thus retriable +func (e Error) Temporary() bool { + // NMAgent will return a 102 (Processing) if the request is taking time to + // complete. These should be attempted again. As such, it's the only + // retriable status code + return e.Code == http.StatusProcessing +} diff --git a/nmagent/responses.go b/nmagent/responses.go new file mode 100644 index 0000000000..48705eff17 --- /dev/null +++ b/nmagent/responses.go @@ -0,0 +1,21 @@ +package nmagent + +type VirtualNetwork struct { + CNetSpace string `json:"cnetSpace"` + DefaultGateway string `json:"defaultGateway"` + DNSServers []string `json:"dnsServers"` + Subnets []Subnet `json:"subnets"` + VNetSpace string `json:"vnetSpace"` + VNetVersion string `json:"vnetVersion"` +} + +type Subnet struct { + AddressPrefix string `json:"addressPrefix"` + SubnetName string `json:"subnetName"` + Tags []Tag `json:"tags"` +} + +type Tag struct { + Name string `json:"name"` + Type string `json:"type"` // the type of the tag (e.g. "System" or "Custom") +} From b9ef256d92090c3c90890622df426d038007e9e7 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 16 Mar 2022 20:14:12 -0400 Subject: [PATCH 04/44] Add retries for Unauthorized responses In the original logic, unauthorized responses are treated as temporary for a specific period of time. This makes the nmagent.Error consider Unauthorized responses as temporary for a configurable time. Given that BackoffRetry cares only whether or not an error is temporary, this naturally causes them to be retried. Additional coverage was added for these scenarios as well. --- nmagent/client.go | 23 ++++++++- nmagent/client_test.go | 100 ++++++++++++++++++++++++++++++++++++++++ nmagent/nmagent.go | 15 ++++-- nmagent/nmagent_test.go | 79 +++++++++++++++++++++++++++++++ 4 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 nmagent/nmagent_test.go diff --git a/nmagent/client.go b/nmagent/client.go index 840ef27c97..928458da6b 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/google/uuid" ) @@ -25,10 +26,16 @@ type Client struct { // config Host string Port string + + // UnauthorizedGracePeriod is the amount of time Unauthorized responses from + // NMAgent will be tolerated and retried + UnauthorizedGracePeriod time.Duration } // JoinNetwork joins a node to a customer's virtual network func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { + requestStart := time.Now() + // we need to be a little defensive, because there is no bad request response // from NMAgent if _, err := uuid.Parse(networkID); err != nil { @@ -54,7 +61,7 @@ func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return Error{resp.StatusCode} + return c.error(time.Since(requestStart), resp.StatusCode) } return nil }) @@ -65,6 +72,8 @@ func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { // GetNetworkConfiguration retrieves the configuration of a customer's virtual // network. Only subnets which have been delegated will be returned func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (VirtualNetwork, error) { + requestStart := time.Now() + path := &url.URL{ Scheme: "https", Host: net.JoinHostPort(c.Host, c.Port), @@ -86,7 +95,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (Vi defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return Error{resp.StatusCode} + return c.error(time.Since(requestStart), resp.StatusCode) } err = json.NewDecoder(resp.Body).Decode(&out) @@ -117,3 +126,13 @@ func (c *Client) GetNmAgentSupportedApiURLFmt(ctx context.Context) error { return nil } */ + +// error constructs a NMAgent error while providing some information configured +// at instantiation +func (c *Client) error(runtime time.Duration, code int) error { + return Error{ + Runtime: runtime, + Limit: c.UnauthorizedGracePeriod, + Code: code, + } +} diff --git a/nmagent/client_test.go b/nmagent/client_test.go index a4ed80e65d..659203f391 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/google/go-cmp/cmp" ) @@ -154,6 +155,54 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { } } +// TODO(timraymond): this is super repetitive (see the retry test) +func TestNMAgentClientJoinNetworkUnauthorized(t *testing.T) { + t.Parallel() + + // we want to ensure that the client will automatically follow up with + // NMAgent, so we want to track the number of requests that it makes + invocations := 0 + exp := 10 + + client := nmagent.Client{ + UnauthorizedGracePeriod: 1 * time.Minute, + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if invocations < exp { + rr.WriteHeader(http.StatusUnauthorized) + invocations++ + } else { + rr.WriteHeader(http.StatusOK) + } + return rr.Result(), nil + }, + }, + }, + } + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + // attempt to join network + err := client.JoinNetwork(ctx, "00000000-0000-0000-0000-000000000000") + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if invocations != exp { + t.Error("client did not make the expected number of API calls: got:", invocations, "exp:", exp) + } +} + func TestNMAgentGetNetworkConfig(t *testing.T) { getTests := []struct { name string @@ -286,3 +335,54 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { t.Error("unexpected number of API calls: exp:", exp, "got:", count) } } + +func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { + t.Parallel() + + count := 0 + exp := 10 + client := &nmagent.Client{ + UnauthorizedGracePeriod: 1 * time.Minute, + HTTPClient: &http.Client{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if count < exp { + rr.WriteHeader(http.StatusUnauthorized) + count++ + } else { + rr.WriteHeader(http.StatusOK) + } + + // we still need a fake response + var out nmagent.VirtualNetwork + err := json.NewEncoder(rr).Encode(&out) + if err != nil { + return nil, err + } + + return rr.Result(), nil + }, + }, + }, + } + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + _, err := client.GetNetworkConfiguration(ctx, "00000000-0000-0000-0000-000000000000") + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if count != exp { + t.Error("unexpected number of API calls: exp:", exp, "got:", count) + } +} diff --git a/nmagent/nmagent.go b/nmagent/nmagent.go index f98183a84b..23a1007228 100644 --- a/nmagent/nmagent.go +++ b/nmagent/nmagent.go @@ -3,12 +3,15 @@ package nmagent import ( "fmt" "net/http" + "time" ) // Error is a aberrent condition encountered when interacting with the NMAgent // API type Error struct { - Code int // the HTTP status code received + Runtime time.Duration // the amount of time the operation has been running + Limit time.Duration // the maximum amount of time the operation can run for + Code int // the HTTP status code received } func (e Error) Error() string { @@ -18,8 +21,14 @@ func (e Error) Error() string { // Temporary reports whether the error encountered from NMAgent should be // considered temporary, and thus retriable func (e Error) Temporary() bool { + // We consider Unauthorized responses from NMAgent to be temporary for a + // certain period of time. This is to allow for situations where an + // authorization token may not yet be available + if e.Code == http.StatusUnauthorized { + return e.Runtime < e.Limit + } + // NMAgent will return a 102 (Processing) if the request is taking time to - // complete. These should be attempted again. As such, it's the only - // retriable status code + // complete. These should be attempted again. return e.Code == http.StatusProcessing } diff --git a/nmagent/nmagent_test.go b/nmagent/nmagent_test.go new file mode 100644 index 0000000000..f77cb40f01 --- /dev/null +++ b/nmagent/nmagent_test.go @@ -0,0 +1,79 @@ +package nmagent_test + +import ( + "dnc/nmagent" + "net/http" + "testing" + "time" +) + +func TestErrorTemp(t *testing.T) { + errorTests := []struct { + name string + err nmagent.Error + shouldTemp bool + }{ + { + "regular", + nmagent.Error{ + Code: http.StatusInternalServerError, + }, + false, + }, + { + "processing", + nmagent.Error{ + Code: http.StatusProcessing, + }, + true, + }, + { + "unauthorized temporary", + nmagent.Error{ + Runtime: 30 * time.Second, + Limit: 1 * time.Minute, + Code: http.StatusUnauthorized, + }, + true, + }, + { + "unauthorized permanent", + nmagent.Error{ + Runtime: 2 * time.Minute, + Limit: 1 * time.Minute, + Code: http.StatusUnauthorized, + }, + false, + }, + { + "unauthorized zero values", + nmagent.Error{ + Code: http.StatusUnauthorized, + }, + false, + }, + { + "unauthorized zero limit", + nmagent.Error{ + Runtime: 2 * time.Minute, + Code: http.StatusUnauthorized, + }, + false, + }, + } + + for _, test := range errorTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if test.err.Temporary() && !test.shouldTemp { + t.Fatal("test was temporary and not expected to be") + } + + if !test.err.Temporary() && test.shouldTemp { + t.Fatal("test was not temporary but expected to be") + } + }) + } +} From 491171d6c15ac25cc26202fd8924b09380a5cb09 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 17 Mar 2022 16:21:23 -0400 Subject: [PATCH 05/44] Add a WireserverTransport This deals with all the quirks of proxying requests to NMAgent through Wireserver, without spreading that concern through the NMAgent client itself. --- nmagent/internal/internal.go | 115 +++++++++ nmagent/internal/internal_test.go | 379 ++++++++++++++++++++++++++++++ 2 files changed, 494 insertions(+) diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go index be813672e2..b8b681aebc 100644 --- a/nmagent/internal/internal.go +++ b/nmagent/internal/internal.go @@ -1,10 +1,125 @@ package internal import ( + "bytes" "context" + "encoding/json" "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" ) +const WirePrefix string = "/machine/plugins/?comp=nmagent&type=" + +var _ http.RoundTripper = &WireserverTransport{} + +// WireserverResponse represents a raw response from Wireserver +type WireserverResponse map[string]json.RawMessage + +func (w WireserverResponse) StatusCode() (int, error) { + if status, ok := w["httpStatusCode"]; ok { + var statusStr string + err := json.Unmarshal(status, &statusStr) + if err != nil { + return 0, fmt.Errorf("unmarshaling httpStatusCode from Wireserver: %w", err) + } + + if code, err := strconv.Atoi(statusStr); err != nil { + return code, fmt.Errorf("parsing http status code from wireserver: %w", err) + } else { + return code, nil + } + } + return 0, fmt.Errorf("no httpStatusCode property returned in Wireserver response") +} + +// WireserverTransport is an http.RoundTripper that applies transformation +// rules to inbound requests necessary to make them compatible with Wireserver +type WireserverTransport struct { + Transport http.RoundTripper +} + +func (w *WireserverTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // the original path of the request must be prefixed with wireserver's path + origPath := req.URL.Path + path := WirePrefix + if req.URL.Path != "" { + path += req.URL.Path[1:] + } + + // the query string from the request must have its constituent parts (?,=,&) + // transformed to slashes and appended to the query + if req.URL.RawQuery != "" { + query := req.URL.RawQuery + query = strings.ReplaceAll(query, "?", "/") + query = strings.ReplaceAll(query, "=", "/") + query = strings.ReplaceAll(query, "&", "/") + path += "/" + query + } + + req.URL.Path = path + // ensure that nothing has changed from the caller's perspective by resetting + // the URL + defer func() { + req.URL.Path = origPath + }() + + // wireserver cannot tolerate PUT requests, so it's necessary to transform those to POSTs + if req.Method == http.MethodPut { + req.Method = http.MethodPost + defer func() { + req.Method = http.MethodPut + }() + } + + // all POST requests (and by extension, PUT) must have a non-nil body + if req.Method == http.MethodPost && req.Body == nil { + req.Body = io.NopCloser(strings.NewReader("")) + } + + // execute the request to the downstream transport + resp, err := w.Transport.RoundTrip(req) + if err != nil { + return resp, err + } + // we want to close this because we're going to replace it + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return resp, nil + } + + // correct the HTTP status returned from wireserver + var wsResp WireserverResponse + err = json.NewDecoder(resp.Body).Decode(&wsResp) + if err != nil { + return resp, fmt.Errorf("decoding json response from wireserver: %w", err) + } + + // set the response status code with the *real* status code + realCode, err := wsResp.StatusCode() + if err != nil { + return resp, fmt.Errorf("retrieving status code from wireserver response: %w", err) + } + + resp.StatusCode = realCode + + // re-encode the body and re-attach it to the response + delete(wsResp, "httpStatusCode") // TODO(timraymond): concern of the response + + body, err := json.Marshal(wsResp) + if err != nil { + return resp, fmt.Errorf("re-encoding json response from wireserver: %w", err) + } + + resp.Body = io.NopCloser(bytes.NewReader(body)) + + return resp, nil +} + type TemporaryError interface { error Temporary() bool diff --git a/nmagent/internal/internal_test.go b/nmagent/internal/internal_test.go index b456ecaf43..f245ed261a 100644 --- a/nmagent/internal/internal_test.go +++ b/nmagent/internal/internal_test.go @@ -2,8 +2,14 @@ package internal import ( "context" + "encoding/json" "errors" + "fmt" + "net/http" + "net/http/httptest" "testing" + + "github.com/google/go-cmp/cmp" ) type TestError struct{} @@ -77,3 +83,376 @@ func TestBackoffRetryUnretriableError(t *testing.T) { t.Fatal("expected an error, but none was returned") } } + +type TestTripper struct { + // TODO(timraymond): this entire struct is duplicated + RoundTripF func(*http.Request) (*http.Response, error) +} + +func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripF(req) +} + +func TestWireserverTransportPathTransform(t *testing.T) { + // Wireserver introduces specific rules on how requests should be + // transformed. This test ensures we got those correct. + + pathTests := []struct { + name string + method string + sub string + exp string + }{ + { + "happy path", + http.MethodGet, + "/test/path", + "/machine/plugins/?comp=nmagent&type=test/path", + }, + { + "empty", + http.MethodGet, + "", + "/machine/plugins/?comp=nmagent&type=", + }, + { + "monopath", + http.MethodGet, + "/foo", + "/machine/plugins/?comp=nmagent&type=foo", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(r *http.Request) (*http.Response, error) { + got = r.URL.Path + rr := httptest.NewRecorder() + rr.WriteHeader(http.StatusOK) + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + return rr.Result(), nil + }, + }, + }, + } + + // execute + + req, err := http.NewRequest(test.method, test.sub, nil) + if err != nil { + t.Fatal("error creating new request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("unexpected error submitting request: err:", err) + } + + // assert + if got != test.exp { + t.Error("received path differs from expectation: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestWireserverTransportStatusTransform(t *testing.T) { + // Wireserver only responds with 200 or 400 and embeds the actual status code + // in JSON. The Transport should correct this and return the actual status as + // an actual status + + statusTests := []struct { + name string + response map[string]interface{} + expBody map[string]interface{} + expStatus int + }{ + { + "401", + map[string]interface{}{ + "httpStatusCode": "401", + }, + map[string]interface{}{}, + http.StatusUnauthorized, + }, + { + "200 with body", + map[string]interface{}{ + "httpStatusCode": "200", + "some": "data", + }, + map[string]interface{}{ + "some": "data", + }, + http.StatusOK, + }, + } + + for _, test := range statusTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(r *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + // mimic Wireserver handing back a 200 regardless: + rr.WriteHeader(http.StatusOK) + + err := json.NewEncoder(rr).Encode(&test.response) + if err != nil { + return nil, fmt.Errorf("encoding json response: %w", err) + } + + return rr.Result(), nil + }, + }, + }, + } + + // execute + + req, err := http.NewRequest(http.MethodGet, "/test/path", nil) + if err != nil { + t.Fatal("error creating new request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("unexpected error submitting request: err:", err) + } + defer resp.Body.Close() + + // assert + gotStatus := resp.StatusCode + if gotStatus != test.expStatus { + t.Errorf("status codes differ: exp: (%d) %s: got (%d): %s", test.expStatus, http.StatusText(test.expStatus), gotStatus, http.StatusText(gotStatus)) + } + + var gotBody map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&gotBody) + if err != nil { + t.Fatal("unexpected error decoding json body: err:", err) + } + + if !cmp.Equal(test.expBody, gotBody) { + t.Error("received body differs from expected: diff:", cmp.Diff(test.expBody, gotBody)) + } + }) + } +} + +func TestWireserverTransportPutPost(t *testing.T) { + // wireserver can't tolerate PUT requests, so they must be transformed to POSTs + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.Method + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + req, err := http.NewRequest(http.MethodPut, "/test/path", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + exp := http.MethodPost + if got != exp { + t.Error("unexpected status: exp:", exp, "got:", got) + } +} + +func TestWireserverTransportPostBody(t *testing.T) { + // all PUT and POST requests must have an empty string body + t.Parallel() + + bodyIsNil := false + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + bodyIsNil = req.Body == nil + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + // PUT + req, err := http.NewRequest(http.MethodPut, "/test/path", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + if bodyIsNil { + t.Error("downstream request body to wireserver was nil, but not expected to be") + } + + // POST + req, err = http.NewRequest(http.MethodPost, "/test/path", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + if bodyIsNil { + t.Error("downstream request body to wireserver was nil, but not expected to be") + } +} + +func TestWireserverTransportQuery(t *testing.T) { + // the query string must have its constituent parts converted to slashes and + // appended to the path + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + exp := "/machine/plugins/?comp=nmagent&type=test/path/api-version/1234/foo/bar" + if got != exp { + t.Error("received request differs from expectation: got:", got, "want:", exp) + } +} + +func TestWireserverResponse(t *testing.T) { + wsRespTests := []struct { + name string + resp string + exp int + shouldErr bool + }{ + { + "empty", + "{}", + 0, + true, + }, + { + "happy path", + `{ + "httpStatusCode": "401" + }`, + 401, + false, + }, + { + "missing code", + `{ + "httpStatusCode": "" + }`, + 0, + true, + }, + { + "other stuff", + `{ + "httpStatusCode": "201", + "other": "stuff" + }`, + 201, + false, + }, + { + "not a string", + `{ + "httpStatusCode": 201, + "other": "stuff" + }`, + 0, + true, + }, + { + "processing", + `{ + "httpStatusCode": "102", + "other": "stuff" + }`, + 102, + false, + }, + } + + for _, test := range wsRespTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var resp WireserverResponse + err := json.Unmarshal([]byte(test.resp), &resp) + if err != nil { + t.Fatal("unexpected unmarshaling error: err:", err) + } + + got, err := resp.StatusCode() + if err != nil && !test.shouldErr { + t.Fatal("unexpected error retrieving status code: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("no error received when one was expected") + } + + if got != test.exp { + t.Error("received incorrect code: got:", got, "want:", test.exp) + } + }) + } +} From b9dcee340bbfe06fabb50ff3d92e08a39876f6b7 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 17 Mar 2022 16:25:26 -0400 Subject: [PATCH 06/44] Reorganize the nmagent internal package The wireserver transport became big enough to warrant its own file --- nmagent/internal/internal.go | 115 --------- nmagent/internal/internal_test.go | 379 --------------------------- nmagent/internal/wireserver.go | 123 +++++++++ nmagent/internal/wireserver_test.go | 384 ++++++++++++++++++++++++++++ 4 files changed, 507 insertions(+), 494 deletions(-) create mode 100644 nmagent/internal/wireserver.go create mode 100644 nmagent/internal/wireserver_test.go diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go index b8b681aebc..be813672e2 100644 --- a/nmagent/internal/internal.go +++ b/nmagent/internal/internal.go @@ -1,125 +1,10 @@ package internal import ( - "bytes" "context" - "encoding/json" "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" ) -const WirePrefix string = "/machine/plugins/?comp=nmagent&type=" - -var _ http.RoundTripper = &WireserverTransport{} - -// WireserverResponse represents a raw response from Wireserver -type WireserverResponse map[string]json.RawMessage - -func (w WireserverResponse) StatusCode() (int, error) { - if status, ok := w["httpStatusCode"]; ok { - var statusStr string - err := json.Unmarshal(status, &statusStr) - if err != nil { - return 0, fmt.Errorf("unmarshaling httpStatusCode from Wireserver: %w", err) - } - - if code, err := strconv.Atoi(statusStr); err != nil { - return code, fmt.Errorf("parsing http status code from wireserver: %w", err) - } else { - return code, nil - } - } - return 0, fmt.Errorf("no httpStatusCode property returned in Wireserver response") -} - -// WireserverTransport is an http.RoundTripper that applies transformation -// rules to inbound requests necessary to make them compatible with Wireserver -type WireserverTransport struct { - Transport http.RoundTripper -} - -func (w *WireserverTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // the original path of the request must be prefixed with wireserver's path - origPath := req.URL.Path - path := WirePrefix - if req.URL.Path != "" { - path += req.URL.Path[1:] - } - - // the query string from the request must have its constituent parts (?,=,&) - // transformed to slashes and appended to the query - if req.URL.RawQuery != "" { - query := req.URL.RawQuery - query = strings.ReplaceAll(query, "?", "/") - query = strings.ReplaceAll(query, "=", "/") - query = strings.ReplaceAll(query, "&", "/") - path += "/" + query - } - - req.URL.Path = path - // ensure that nothing has changed from the caller's perspective by resetting - // the URL - defer func() { - req.URL.Path = origPath - }() - - // wireserver cannot tolerate PUT requests, so it's necessary to transform those to POSTs - if req.Method == http.MethodPut { - req.Method = http.MethodPost - defer func() { - req.Method = http.MethodPut - }() - } - - // all POST requests (and by extension, PUT) must have a non-nil body - if req.Method == http.MethodPost && req.Body == nil { - req.Body = io.NopCloser(strings.NewReader("")) - } - - // execute the request to the downstream transport - resp, err := w.Transport.RoundTrip(req) - if err != nil { - return resp, err - } - // we want to close this because we're going to replace it - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return resp, nil - } - - // correct the HTTP status returned from wireserver - var wsResp WireserverResponse - err = json.NewDecoder(resp.Body).Decode(&wsResp) - if err != nil { - return resp, fmt.Errorf("decoding json response from wireserver: %w", err) - } - - // set the response status code with the *real* status code - realCode, err := wsResp.StatusCode() - if err != nil { - return resp, fmt.Errorf("retrieving status code from wireserver response: %w", err) - } - - resp.StatusCode = realCode - - // re-encode the body and re-attach it to the response - delete(wsResp, "httpStatusCode") // TODO(timraymond): concern of the response - - body, err := json.Marshal(wsResp) - if err != nil { - return resp, fmt.Errorf("re-encoding json response from wireserver: %w", err) - } - - resp.Body = io.NopCloser(bytes.NewReader(body)) - - return resp, nil -} - type TemporaryError interface { error Temporary() bool diff --git a/nmagent/internal/internal_test.go b/nmagent/internal/internal_test.go index f245ed261a..b456ecaf43 100644 --- a/nmagent/internal/internal_test.go +++ b/nmagent/internal/internal_test.go @@ -2,14 +2,8 @@ package internal import ( "context" - "encoding/json" "errors" - "fmt" - "net/http" - "net/http/httptest" "testing" - - "github.com/google/go-cmp/cmp" ) type TestError struct{} @@ -83,376 +77,3 @@ func TestBackoffRetryUnretriableError(t *testing.T) { t.Fatal("expected an error, but none was returned") } } - -type TestTripper struct { - // TODO(timraymond): this entire struct is duplicated - RoundTripF func(*http.Request) (*http.Response, error) -} - -func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return t.RoundTripF(req) -} - -func TestWireserverTransportPathTransform(t *testing.T) { - // Wireserver introduces specific rules on how requests should be - // transformed. This test ensures we got those correct. - - pathTests := []struct { - name string - method string - sub string - exp string - }{ - { - "happy path", - http.MethodGet, - "/test/path", - "/machine/plugins/?comp=nmagent&type=test/path", - }, - { - "empty", - http.MethodGet, - "", - "/machine/plugins/?comp=nmagent&type=", - }, - { - "monopath", - http.MethodGet, - "/foo", - "/machine/plugins/?comp=nmagent&type=foo", - }, - } - - for _, test := range pathTests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - var got string - client := &http.Client{ - Transport: &WireserverTransport{ - Transport: &TestTripper{ - RoundTripF: func(r *http.Request) (*http.Response, error) { - got = r.URL.Path - rr := httptest.NewRecorder() - rr.WriteHeader(http.StatusOK) - rr.Write([]byte(`{"httpStatusCode": "200"}`)) - return rr.Result(), nil - }, - }, - }, - } - - // execute - - req, err := http.NewRequest(test.method, test.sub, nil) - if err != nil { - t.Fatal("error creating new request: err:", err) - } - - _, err = client.Do(req) - if err != nil { - t.Fatal("unexpected error submitting request: err:", err) - } - - // assert - if got != test.exp { - t.Error("received path differs from expectation: exp:", test.exp, "got:", got) - } - }) - } -} - -func TestWireserverTransportStatusTransform(t *testing.T) { - // Wireserver only responds with 200 or 400 and embeds the actual status code - // in JSON. The Transport should correct this and return the actual status as - // an actual status - - statusTests := []struct { - name string - response map[string]interface{} - expBody map[string]interface{} - expStatus int - }{ - { - "401", - map[string]interface{}{ - "httpStatusCode": "401", - }, - map[string]interface{}{}, - http.StatusUnauthorized, - }, - { - "200 with body", - map[string]interface{}{ - "httpStatusCode": "200", - "some": "data", - }, - map[string]interface{}{ - "some": "data", - }, - http.StatusOK, - }, - } - - for _, test := range statusTests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - client := &http.Client{ - Transport: &WireserverTransport{ - Transport: &TestTripper{ - RoundTripF: func(r *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - // mimic Wireserver handing back a 200 regardless: - rr.WriteHeader(http.StatusOK) - - err := json.NewEncoder(rr).Encode(&test.response) - if err != nil { - return nil, fmt.Errorf("encoding json response: %w", err) - } - - return rr.Result(), nil - }, - }, - }, - } - - // execute - - req, err := http.NewRequest(http.MethodGet, "/test/path", nil) - if err != nil { - t.Fatal("error creating new request: err:", err) - } - - resp, err := client.Do(req) - if err != nil { - t.Fatal("unexpected error submitting request: err:", err) - } - defer resp.Body.Close() - - // assert - gotStatus := resp.StatusCode - if gotStatus != test.expStatus { - t.Errorf("status codes differ: exp: (%d) %s: got (%d): %s", test.expStatus, http.StatusText(test.expStatus), gotStatus, http.StatusText(gotStatus)) - } - - var gotBody map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&gotBody) - if err != nil { - t.Fatal("unexpected error decoding json body: err:", err) - } - - if !cmp.Equal(test.expBody, gotBody) { - t.Error("received body differs from expected: diff:", cmp.Diff(test.expBody, gotBody)) - } - }) - } -} - -func TestWireserverTransportPutPost(t *testing.T) { - // wireserver can't tolerate PUT requests, so they must be transformed to POSTs - t.Parallel() - - var got string - client := &http.Client{ - Transport: &WireserverTransport{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - got = req.Method - rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) - rr.WriteHeader(http.StatusOK) - return rr.Result(), nil - }, - }, - }, - } - - req, err := http.NewRequest(http.MethodPut, "/test/path", nil) - if err != nil { - t.Fatal("unexpected error creating http request: err:", err) - } - - _, err = client.Do(req) - if err != nil { - t.Fatal("error submitting request: err:", err) - } - - exp := http.MethodPost - if got != exp { - t.Error("unexpected status: exp:", exp, "got:", got) - } -} - -func TestWireserverTransportPostBody(t *testing.T) { - // all PUT and POST requests must have an empty string body - t.Parallel() - - bodyIsNil := false - client := &http.Client{ - Transport: &WireserverTransport{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - bodyIsNil = req.Body == nil - rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) - rr.WriteHeader(http.StatusOK) - return rr.Result(), nil - }, - }, - }, - } - - // PUT - req, err := http.NewRequest(http.MethodPut, "/test/path", nil) - if err != nil { - t.Fatal("unexpected error creating http request: err:", err) - } - - _, err = client.Do(req) - if err != nil { - t.Fatal("error submitting request: err:", err) - } - - if bodyIsNil { - t.Error("downstream request body to wireserver was nil, but not expected to be") - } - - // POST - req, err = http.NewRequest(http.MethodPost, "/test/path", nil) - if err != nil { - t.Fatal("unexpected error creating http request: err:", err) - } - - _, err = client.Do(req) - if err != nil { - t.Fatal("error submitting request: err:", err) - } - - if bodyIsNil { - t.Error("downstream request body to wireserver was nil, but not expected to be") - } -} - -func TestWireserverTransportQuery(t *testing.T) { - // the query string must have its constituent parts converted to slashes and - // appended to the path - t.Parallel() - - var got string - client := &http.Client{ - Transport: &WireserverTransport{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - got = req.URL.Path - rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) - rr.WriteHeader(http.StatusOK) - return rr.Result(), nil - }, - }, - }, - } - - req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", nil) - if err != nil { - t.Fatal("unexpected error creating http request: err:", err) - } - - _, err = client.Do(req) - if err != nil { - t.Fatal("error submitting request: err:", err) - } - - exp := "/machine/plugins/?comp=nmagent&type=test/path/api-version/1234/foo/bar" - if got != exp { - t.Error("received request differs from expectation: got:", got, "want:", exp) - } -} - -func TestWireserverResponse(t *testing.T) { - wsRespTests := []struct { - name string - resp string - exp int - shouldErr bool - }{ - { - "empty", - "{}", - 0, - true, - }, - { - "happy path", - `{ - "httpStatusCode": "401" - }`, - 401, - false, - }, - { - "missing code", - `{ - "httpStatusCode": "" - }`, - 0, - true, - }, - { - "other stuff", - `{ - "httpStatusCode": "201", - "other": "stuff" - }`, - 201, - false, - }, - { - "not a string", - `{ - "httpStatusCode": 201, - "other": "stuff" - }`, - 0, - true, - }, - { - "processing", - `{ - "httpStatusCode": "102", - "other": "stuff" - }`, - 102, - false, - }, - } - - for _, test := range wsRespTests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - var resp WireserverResponse - err := json.Unmarshal([]byte(test.resp), &resp) - if err != nil { - t.Fatal("unexpected unmarshaling error: err:", err) - } - - got, err := resp.StatusCode() - if err != nil && !test.shouldErr { - t.Fatal("unexpected error retrieving status code: err:", err) - } - - if err == nil && test.shouldErr { - t.Fatal("no error received when one was expected") - } - - if got != test.exp { - t.Error("received incorrect code: got:", got, "want:", test.exp) - } - }) - } -} diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go new file mode 100644 index 0000000000..d3d6a96a4a --- /dev/null +++ b/nmagent/internal/wireserver.go @@ -0,0 +1,123 @@ +package internal + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" +) + +const WirePrefix string = "/machine/plugins/?comp=nmagent&type=" + +var _ http.RoundTripper = &WireserverTransport{} + +// WireserverResponse represents a raw response from Wireserver +type WireserverResponse map[string]json.RawMessage + +// StatusCode extracts the embedded HTTP status code from the response from Wireserver +func (w WireserverResponse) StatusCode() (int, error) { + if status, ok := w["httpStatusCode"]; ok { + var statusStr string + err := json.Unmarshal(status, &statusStr) + if err != nil { + return 0, fmt.Errorf("unmarshaling httpStatusCode from Wireserver: %w", err) + } + + if code, err := strconv.Atoi(statusStr); err != nil { + return code, fmt.Errorf("parsing http status code from wireserver: %w", err) + } else { + return code, nil + } + } + return 0, fmt.Errorf("no httpStatusCode property returned in Wireserver response") +} + +// WireserverTransport is an http.RoundTripper that applies transformation +// rules to inbound requests necessary to make them compatible with Wireserver +type WireserverTransport struct { + Transport http.RoundTripper +} + +// RoundTrip executes arbitrary HTTP requests against Wireserver while applying +// the necessary transformation rules to make such requests acceptable to +// Wireserver +func (w *WireserverTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // the original path of the request must be prefixed with wireserver's path + origPath := req.URL.Path + path := WirePrefix + if req.URL.Path != "" { + path += req.URL.Path[1:] + } + + // the query string from the request must have its constituent parts (?,=,&) + // transformed to slashes and appended to the query + if req.URL.RawQuery != "" { + query := req.URL.RawQuery + query = strings.ReplaceAll(query, "?", "/") + query = strings.ReplaceAll(query, "=", "/") + query = strings.ReplaceAll(query, "&", "/") + path += "/" + query + } + + req.URL.Path = path + // ensure that nothing has changed from the caller's perspective by resetting + // the URL + defer func() { + req.URL.Path = origPath + }() + + // wireserver cannot tolerate PUT requests, so it's necessary to transform those to POSTs + if req.Method == http.MethodPut { + req.Method = http.MethodPost + defer func() { + req.Method = http.MethodPut + }() + } + + // all POST requests (and by extension, PUT) must have a non-nil body + if req.Method == http.MethodPost && req.Body == nil { + req.Body = io.NopCloser(strings.NewReader("")) + } + + // execute the request to the downstream transport + resp, err := w.Transport.RoundTrip(req) + if err != nil { + return resp, err + } + // we want to close this because we're going to replace it + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return resp, nil + } + + // correct the HTTP status returned from wireserver + var wsResp WireserverResponse + err = json.NewDecoder(resp.Body).Decode(&wsResp) + if err != nil { + return resp, fmt.Errorf("decoding json response from wireserver: %w", err) + } + + // set the response status code with the *real* status code + realCode, err := wsResp.StatusCode() + if err != nil { + return resp, fmt.Errorf("retrieving status code from wireserver response: %w", err) + } + + resp.StatusCode = realCode + + // re-encode the body and re-attach it to the response + delete(wsResp, "httpStatusCode") // TODO(timraymond): concern of the response + + body, err := json.Marshal(wsResp) + if err != nil { + return resp, fmt.Errorf("re-encoding json response from wireserver: %w", err) + } + + resp.Body = io.NopCloser(bytes.NewReader(body)) + + return resp, nil +} diff --git a/nmagent/internal/wireserver_test.go b/nmagent/internal/wireserver_test.go new file mode 100644 index 0000000000..f78abc5204 --- /dev/null +++ b/nmagent/internal/wireserver_test.go @@ -0,0 +1,384 @@ +package internal + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/go-cmp/cmp" +) + +type TestTripper struct { + // TODO(timraymond): this entire struct is duplicated + RoundTripF func(*http.Request) (*http.Response, error) +} + +func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripF(req) +} + +func TestWireserverTransportPathTransform(t *testing.T) { + // Wireserver introduces specific rules on how requests should be + // transformed. This test ensures we got those correct. + + pathTests := []struct { + name string + method string + sub string + exp string + }{ + { + "happy path", + http.MethodGet, + "/test/path", + "/machine/plugins/?comp=nmagent&type=test/path", + }, + { + "empty", + http.MethodGet, + "", + "/machine/plugins/?comp=nmagent&type=", + }, + { + "monopath", + http.MethodGet, + "/foo", + "/machine/plugins/?comp=nmagent&type=foo", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(r *http.Request) (*http.Response, error) { + got = r.URL.Path + rr := httptest.NewRecorder() + rr.WriteHeader(http.StatusOK) + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + return rr.Result(), nil + }, + }, + }, + } + + // execute + + req, err := http.NewRequest(test.method, test.sub, nil) + if err != nil { + t.Fatal("error creating new request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("unexpected error submitting request: err:", err) + } + + // assert + if got != test.exp { + t.Error("received path differs from expectation: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestWireserverTransportStatusTransform(t *testing.T) { + // Wireserver only responds with 200 or 400 and embeds the actual status code + // in JSON. The Transport should correct this and return the actual status as + // an actual status + + statusTests := []struct { + name string + response map[string]interface{} + expBody map[string]interface{} + expStatus int + }{ + { + "401", + map[string]interface{}{ + "httpStatusCode": "401", + }, + map[string]interface{}{}, + http.StatusUnauthorized, + }, + { + "200 with body", + map[string]interface{}{ + "httpStatusCode": "200", + "some": "data", + }, + map[string]interface{}{ + "some": "data", + }, + http.StatusOK, + }, + } + + for _, test := range statusTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(r *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + // mimic Wireserver handing back a 200 regardless: + rr.WriteHeader(http.StatusOK) + + err := json.NewEncoder(rr).Encode(&test.response) + if err != nil { + return nil, fmt.Errorf("encoding json response: %w", err) + } + + return rr.Result(), nil + }, + }, + }, + } + + // execute + + req, err := http.NewRequest(http.MethodGet, "/test/path", nil) + if err != nil { + t.Fatal("error creating new request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("unexpected error submitting request: err:", err) + } + defer resp.Body.Close() + + // assert + gotStatus := resp.StatusCode + if gotStatus != test.expStatus { + t.Errorf("status codes differ: exp: (%d) %s: got (%d): %s", test.expStatus, http.StatusText(test.expStatus), gotStatus, http.StatusText(gotStatus)) + } + + var gotBody map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&gotBody) + if err != nil { + t.Fatal("unexpected error decoding json body: err:", err) + } + + if !cmp.Equal(test.expBody, gotBody) { + t.Error("received body differs from expected: diff:", cmp.Diff(test.expBody, gotBody)) + } + }) + } +} + +func TestWireserverTransportPutPost(t *testing.T) { + // wireserver can't tolerate PUT requests, so they must be transformed to POSTs + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.Method + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + req, err := http.NewRequest(http.MethodPut, "/test/path", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + exp := http.MethodPost + if got != exp { + t.Error("unexpected status: exp:", exp, "got:", got) + } +} + +func TestWireserverTransportPostBody(t *testing.T) { + // all PUT and POST requests must have an empty string body + t.Parallel() + + bodyIsNil := false + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + bodyIsNil = req.Body == nil + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + // PUT + req, err := http.NewRequest(http.MethodPut, "/test/path", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + if bodyIsNil { + t.Error("downstream request body to wireserver was nil, but not expected to be") + } + + // POST + req, err = http.NewRequest(http.MethodPost, "/test/path", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + if bodyIsNil { + t.Error("downstream request body to wireserver was nil, but not expected to be") + } +} + +func TestWireserverTransportQuery(t *testing.T) { + // the query string must have its constituent parts converted to slashes and + // appended to the path + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", nil) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + _, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + + exp := "/machine/plugins/?comp=nmagent&type=test/path/api-version/1234/foo/bar" + if got != exp { + t.Error("received request differs from expectation: got:", got, "want:", exp) + } +} + +func TestWireserverResponse(t *testing.T) { + wsRespTests := []struct { + name string + resp string + exp int + shouldErr bool + }{ + { + "empty", + "{}", + 0, + true, + }, + { + "happy path", + `{ + "httpStatusCode": "401" + }`, + 401, + false, + }, + { + "missing code", + `{ + "httpStatusCode": "" + }`, + 0, + true, + }, + { + "other stuff", + `{ + "httpStatusCode": "201", + "other": "stuff" + }`, + 201, + false, + }, + { + "not a string", + `{ + "httpStatusCode": 201, + "other": "stuff" + }`, + 0, + true, + }, + { + "processing", + `{ + "httpStatusCode": "102", + "other": "stuff" + }`, + 102, + false, + }, + } + + for _, test := range wsRespTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var resp WireserverResponse + err := json.Unmarshal([]byte(test.resp), &resp) + if err != nil { + t.Fatal("unexpected unmarshaling error: err:", err) + } + + got, err := resp.StatusCode() + if err != nil && !test.shouldErr { + t.Fatal("unexpected error retrieving status code: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("no error received when one was expected") + } + + if got != test.exp { + t.Error("received incorrect code: got:", got, "want:", test.exp) + } + }) + } +} From 0a3903aff3b57a8e6d0dce6f3a97f94739c3a0fd Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 18 Mar 2022 20:06:45 -0400 Subject: [PATCH 07/44] Use WireserverTransport This required some changes to the test so that the WireserverTransport middleware could take effect always --- nmagent/client.go | 29 +++-- nmagent/client_helpers_test.go | 18 +++ nmagent/client_test.go | 206 +++++++++++++++------------------ 3 files changed, 133 insertions(+), 120 deletions(-) create mode 100644 nmagent/client_helpers_test.go diff --git a/nmagent/client.go b/nmagent/client.go index 928458da6b..dc8e4aa7a7 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -8,20 +8,33 @@ import ( "net" "net/http" "net/url" - "strings" "time" "github.com/google/uuid" ) const ( - JoinNetworkPath string = "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/%s/api-version/1" - GetNetworkConfigPath string = "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + JoinNetworkPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" ) +// NewClient returns an initialized Client using the provided configuration +func NewClient(host, port string, grace time.Duration) *Client { + return &Client{ + httpClient: &http.Client{ + Transport: &internal.WireserverTransport{ + Transport: http.DefaultTransport, + }, + }, + Host: host, + Port: port, + UnauthorizedGracePeriod: grace, + } +} + // Client is an agent for exchanging information with NMAgent type Client struct { - HTTPClient *http.Client + httpClient *http.Client // config Host string @@ -48,13 +61,13 @@ func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { Path: fmt.Sprintf(JoinNetworkPath, networkID), } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL.String(), strings.NewReader("")) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL.String(), nil) if err != nil { return fmt.Errorf("creating request: %w", err) } err = internal.BackoffRetry(ctx, func() error { - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("executing request: %w", err) } @@ -82,13 +95,13 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (Vi var out VirtualNetwork - req, err := http.NewRequestWithContext(ctx, http.MethodGet, path.String(), strings.NewReader("")) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path.String(), nil) if err != nil { return out, fmt.Errorf("creating http request to %q: %w", path.String(), err) } err = internal.BackoffRetry(ctx, func() error { - resp, err := c.HTTPClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("executing http request to %q: %w", path.String(), err) } diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go new file mode 100644 index 0000000000..1578abad87 --- /dev/null +++ b/nmagent/client_helpers_test.go @@ -0,0 +1,18 @@ +package nmagent + +import ( + "dnc/nmagent/internal" + "net/http" +) + +// NewTestClient creates an NMAgent Client suitable for use in tests. This is +// unavailable in production builds +func NewTestClient(tripper http.RoundTripper) *Client { + return &Client{ + httpClient: &http.Client{ + Transport: &internal.WireserverTransport{ + Transport: tripper, + }, + }, + } +} diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 659203f391..312e1df694 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -70,18 +70,15 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { // create a client var got string - client := nmagent.Client{ - HTTPClient: &http.Client{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - got = req.URL.Path - rr := httptest.NewRecorder() - rr.WriteHeader(test.respStatus) - return rr.Result(), nil - }, - }, + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + rr.Write([]byte(fmt.Sprintf(`{"httpStatusCode":"%d"}`, test.respStatus))) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil }, - } + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -117,22 +114,19 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { invocations := 0 exp := 10 - client := nmagent.Client{ - HTTPClient: &http.Client{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - if invocations < exp { - rr.WriteHeader(http.StatusProcessing) - invocations++ - } else { - rr.WriteHeader(http.StatusOK) - } - return rr.Result(), nil - }, - }, + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if invocations < exp { + rr.WriteHeader(http.StatusProcessing) + invocations++ + } else { + rr.WriteHeader(http.StatusOK) + } + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + return rr.Result(), nil }, - } + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -164,23 +158,21 @@ func TestNMAgentClientJoinNetworkUnauthorized(t *testing.T) { invocations := 0 exp := 10 - client := nmagent.Client{ - UnauthorizedGracePeriod: 1 * time.Minute, - HTTPClient: &http.Client{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - if invocations < exp { - rr.WriteHeader(http.StatusUnauthorized) - invocations++ - } else { - rr.WriteHeader(http.StatusOK) - } - return rr.Result(), nil - }, - }, + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if invocations < exp { + rr.WriteHeader(http.StatusUnauthorized) + invocations++ + } else { + rr.WriteHeader(http.StatusOK) + } + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + return rr.Result(), nil }, - } + }) + + client.UnauthorizedGracePeriod = 1 * time.Minute // if the test provides a timeout, use it in the context var ctx context.Context @@ -208,7 +200,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { name string vnetID string expURL string - expVNet nmagent.VirtualNetwork + expVNet map[string]interface{} shouldCall bool shouldErr bool }{ @@ -216,16 +208,17 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { "happy path", "00000000-0000-0000-0000-000000000000", "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", - nmagent.VirtualNetwork{ - CNetSpace: "10.10.1.0/24", - DefaultGateway: "10.10.0.1", - DNSServers: []string{ + map[string]interface{}{ + "httpStatusCode": "200", + "cnetSpace": "10.10.1.0/24", + "defaultGateway": "10.10.0.1", + "dnsServers": []string{ "1.1.1.1", "1.0.0.1", }, - Subnets: []nmagent.Subnet{}, - VNetSpace: "10.0.0.0/8", - VNetVersion: "2018", // TODO(timraymond): what's a real version look like? + "subnets": []map[string]interface{}{}, + "vnetSpace": "10.0.0.0/8", + "vnetVersion": "12345", }, true, false, @@ -238,23 +231,19 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { t.Parallel() var got string - client := &nmagent.Client{ - HTTPClient: &http.Client{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - got = req.URL.Path - rr.WriteHeader(http.StatusOK) - err := json.NewEncoder(rr).Encode(&test.expVNet) - if err != nil { - return nil, fmt.Errorf("encoding response: %w", err) - } - - return rr.Result(), nil - }, - }, + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + got = req.URL.Path + rr.WriteHeader(http.StatusOK) + err := json.NewEncoder(rr).Encode(&test.expVNet) + if err != nil { + return nil, fmt.Errorf("encoding response: %w", err) + } + + return rr.Result(), nil }, - } + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -279,8 +268,17 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { t.Error("unexpected URL: got:", got, "exp:", test.expURL) } - if !cmp.Equal(gotVNet, test.expVNet) { - t.Error("received vnet differs from expected: diff:", cmp.Diff(gotVNet, test.expVNet)) + // TODO(timraymond): this is ugly + expVnet := nmagent.VirtualNetwork{ + CNetSpace: test.expVNet["cnetSpace"].(string), + DefaultGateway: test.expVNet["defaultGateway"].(string), + DNSServers: test.expVNet["dnsServers"].([]string), + Subnets: []nmagent.Subnet{}, + VNetSpace: test.expVNet["vnetSpace"].(string), + VNetVersion: test.expVNet["vnetVersion"].(string), + } + if !cmp.Equal(gotVNet, expVnet) { + t.Error("received vnet differs from expected: diff:", cmp.Diff(gotVNet, expVnet)) } }) } @@ -291,30 +289,21 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { count := 0 exp := 10 - client := &nmagent.Client{ - HTTPClient: &http.Client{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - if count < exp { - rr.WriteHeader(http.StatusProcessing) - count++ - } else { - rr.WriteHeader(http.StatusOK) - } - - // we still need a fake response - var out nmagent.VirtualNetwork - err := json.NewEncoder(rr).Encode(&out) - if err != nil { - return nil, err - } + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if count < exp { + rr.WriteHeader(http.StatusProcessing) + count++ + } else { + rr.WriteHeader(http.StatusOK) + } - return rr.Result(), nil - }, - }, + // we still need a fake response + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + return rr.Result(), nil }, - } + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -341,31 +330,24 @@ func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { count := 0 exp := 10 - client := &nmagent.Client{ - UnauthorizedGracePeriod: 1 * time.Minute, - HTTPClient: &http.Client{ - Transport: &TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - if count < exp { - rr.WriteHeader(http.StatusUnauthorized) - count++ - } else { - rr.WriteHeader(http.StatusOK) - } + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if count < exp { + rr.WriteHeader(http.StatusUnauthorized) + count++ + } else { + rr.WriteHeader(http.StatusOK) + } - // we still need a fake response - var out nmagent.VirtualNetwork - err := json.NewEncoder(rr).Encode(&out) - if err != nil { - return nil, err - } + // we still need a fake response + rr.Write([]byte(`{"httpStatusCode": "200"}`)) - return rr.Result(), nil - }, - }, + return rr.Result(), nil }, - } + }) + + client.UnauthorizedGracePeriod = 1 * time.Minute // if the test provides a timeout, use it in the context var ctx context.Context From e8a28c2554fc8a69a4550020e421ff4642fffb0e Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 21 Mar 2022 15:02:31 -0400 Subject: [PATCH 08/44] Add PutNetworkContainer method to NMAgent client This is another API that must be implemented --- nmagent/client.go | 38 +++++++++++++++++++-- nmagent/client_test.go | 70 ++++++++++++++++++++++++++++++++++++++ nmagent/requests.go | 73 ++++++++++++++++++++++++++++++++++++++++ nmagent/requests_test.go | 52 ++++++++++++++++++++++++++++ 4 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 nmagent/requests.go create mode 100644 nmagent/requests_test.go diff --git a/nmagent/client.go b/nmagent/client.go index dc8e4aa7a7..4f726e62d2 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -1,6 +1,7 @@ package nmagent import ( + "bytes" "context" "dnc/nmagent/internal" "encoding/json" @@ -16,6 +17,7 @@ import ( const ( JoinNetworkPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" ) // NewClient returns an initialized Client using the provided configuration @@ -126,8 +128,35 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (Vi return out, nil } -/* -func (c *Client) PutNetworkContainer(ctx context.Context) error { +// PutNetworkContainer applies a Network Container goal state and publishes it +// to PubSub +func (c *Client) PutNetworkContainer(ctx context.Context, nc NetworkContainerRequest) error { + requestStart := time.Now() + + path := &url.URL{ + Scheme: "https", + Host: c.hostPort(), + Path: fmt.Sprintf(PutNCRequestPath, nc.PrimaryAddress, nc.ID, nc.AuthenticationToken), + } + + body, err := json.Marshal(nc) + if err != nil { + return fmt.Errorf("encoding request as JSON: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, path.String(), bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("submitting request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return c.error(time.Since(requestStart), resp.StatusCode) + } return nil } @@ -138,7 +167,10 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context) error { func (c *Client) GetNmAgentSupportedApiURLFmt(ctx context.Context) error { return nil } -*/ + +func (c *Client) hostPort() string { + return net.JoinHostPort(c.Host, c.Port) +} // error constructs a NMAgent error while providing some information configured // at instantiation diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 312e1df694..31c0ba816d 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -368,3 +368,73 @@ func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { t.Error("unexpected number of API calls: exp:", exp, "got:", count) } } + +func TestNMAgentPutNetworkContainer(t *testing.T) { + putNCTests := []struct { + name string + req nmagent.NetworkContainerRequest + shouldCall bool + shouldErr bool + }{ + { + "happy path", + nmagent.NetworkContainerRequest{ + ID: "350f1e3c-4283-4f51-83a1-c44253962ef1", + Version: uint64(12345), + VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36", + SubnetName: "TestSubnet", + IPv4Addrs: []string{"10.0.0.43"}, + Policies: []nmagent.Policy{ + { + ID: "policyID1", + Type: "type1", + }, + { + ID: "policyID2", + Type: "type2", + }, + }, + VlanID: 1234, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + true, + false, + }, + } + + for _, test := range putNCTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + didCall := false + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + rr.WriteHeader(http.StatusOK) + didCall = true + return rr.Result(), nil + }, + }) + + err := client.PutNetworkContainer(context.TODO(), test.req) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if test.shouldCall && !didCall { + t.Fatal("expected call but received none") + } + + if !test.shouldCall && didCall { + t.Fatal("unexpected call. expected no call ") + } + }) + } +} diff --git a/nmagent/requests.go b/nmagent/requests.go new file mode 100644 index 0000000000..6882bc9474 --- /dev/null +++ b/nmagent/requests.go @@ -0,0 +1,73 @@ +package nmagent + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "unicode" +) + +type Policy struct { + ID string + Type string +} + +// MarshalJson encodes policies as a JSON string, separated by a comma. This +// specific format is requested by the NMAgent documentation +func (p Policy) MarshalJSON() ([]byte, error) { + out := bytes.NewBufferString(p.ID) + out.WriteString(", ") + out.WriteString(p.Type) + + outStr := out.String() + return json.Marshal(outStr) +} + +// UnmarshalJSON decodes a JSON-encoded policy string +func (p *Policy) UnmarshalJSON(in []byte) error { + var raw string + err := json.Unmarshal(in, &raw) + if err != nil { + return fmt.Errorf("decoding policy: %w", err) + } + + parts := strings.Split(raw, ",") + if len(parts) != 2 { + return fmt.Errorf("policies must be two comma-separated values") + } + + p.ID = strings.TrimFunc(parts[0], unicode.IsSpace) + p.Type = strings.TrimFunc(parts[1], unicode.IsSpace) + + return nil +} + +type NetworkContainerRequest struct { + ID string `json:"networkContainerID"` // the id of the network container + VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet + Version uint64 `json:"version"` // the new network container version + + // SubnetName is the name of the delegated subnet. This is used to + // authenticate the request. The list of ipv4addresses must be contained in + // the subnet's prefix. + SubnetName string `json:"subnetName"` + + // IPv4 addresses in the customer virtual network that will be assigned to + // the interface. + IPv4Addrs []string `json:"ipV4Addresses"` + + Policies []Policy `json:"policies"` // policies applied to the network container + + // VlanID is used to distinguish Network Containers with duplicate customer + // addresses. "0" is considered a default value by the API. + VlanID int `json:"vlanId"` + + // AuthenticationToken is the base64 security token for the subnet containing + // the Network Container addresses + AuthenticationToken string `json:"-"` + + // PrimaryAddress is the primary customer address of the interface in the + // management VNet + PrimaryAddress string `json:"-"` +} diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go new file mode 100644 index 0000000000..f7640440ba --- /dev/null +++ b/nmagent/requests_test.go @@ -0,0 +1,52 @@ +package nmagent_test + +import ( + "dnc/nmagent" + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestPolicyMarshal(t *testing.T) { + policyTests := []struct { + name string + policy nmagent.Policy + exp string + }{ + { + "basic", + nmagent.Policy{ + ID: "policyID1", + Type: "type1", + }, + "\"policyID1, type1\"", + }, + } + + for _, test := range policyTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + got, err := json.Marshal(test.policy) + if err != nil { + t.Fatal("unexpected err marshaling policy: err", err) + } + + if string(got) != test.exp { + t.Errorf("marshaled policy does not match expectation: got: %q: exp: %q", string(got), test.exp) + } + + var enc nmagent.Policy + err = json.Unmarshal(got, &enc) + if err != nil { + t.Fatal("unexpected error unmarshaling: err:", err) + } + + if !cmp.Equal(enc, test.policy) { + t.Error("re-encoded policy differs from expectation: diff:", cmp.Diff(enc, test.policy)) + } + }) + } +} From 32786867d081ec642a8064a25ab1ae118f88ef97 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 21 Mar 2022 20:11:39 -0400 Subject: [PATCH 09/44] Switch NMAgent client port to uint16 Ports are uint16s by definition. --- nmagent/client.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 4f726e62d2..f6f1f89e28 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "strconv" "time" "github.com/google/uuid" @@ -21,7 +22,7 @@ const ( ) // NewClient returns an initialized Client using the provided configuration -func NewClient(host, port string, grace time.Duration) *Client { +func NewClient(host string, port uint16, grace time.Duration) *Client { return &Client{ httpClient: &http.Client{ Transport: &internal.WireserverTransport{ @@ -40,7 +41,7 @@ type Client struct { // config Host string - Port string + Port uint16 // UnauthorizedGracePeriod is the amount of time Unauthorized responses from // NMAgent will be tolerated and retried @@ -59,7 +60,7 @@ func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { joinURL := &url.URL{ Scheme: "https", - Host: net.JoinHostPort(c.Host, c.Port), + Host: c.hostPort(), Path: fmt.Sprintf(JoinNetworkPath, networkID), } @@ -91,7 +92,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (Vi path := &url.URL{ Scheme: "https", - Host: net.JoinHostPort(c.Host, c.Port), + Host: c.hostPort(), Path: fmt.Sprintf(GetNetworkConfigPath, vnetID), } @@ -169,7 +170,8 @@ func (c *Client) GetNmAgentSupportedApiURLFmt(ctx context.Context) error { } func (c *Client) hostPort() string { - return net.JoinHostPort(c.Host, c.Port) + port := strconv.Itoa(int(c.Port)) + return net.JoinHostPort(c.Host, port) } // error constructs a NMAgent error while providing some information configured From 748680fef1ddfb44a2d6e40648b1b109102884c3 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 21 Mar 2022 20:12:28 -0400 Subject: [PATCH 10/44] Add missing body close and context propagation --- nmagent/client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nmagent/client.go b/nmagent/client.go index f6f1f89e28..2220b2d477 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -145,7 +145,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, nc NetworkContainerReq return fmt.Errorf("encoding request as JSON: %w", err) } - req, err := http.NewRequest(http.MethodPost, path.String(), bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, path.String(), bytes.NewReader(body)) if err != nil { return fmt.Errorf("creating request: %w", err) } @@ -154,6 +154,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, nc NetworkContainerReq if err != nil { return fmt.Errorf("submitting request: %w", err) } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return c.error(time.Since(requestStart), resp.StatusCode) From bf2f713c4a870a9a651d296785a3b0fa1d122d30 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 21 Mar 2022 20:13:01 -0400 Subject: [PATCH 11/44] Add DeleteNetworkContainer endpoint --- nmagent/client.go | 33 +++++++++++++++++++++--- nmagent/client_test.go | 48 +++++++++++++++++++++++++++++++++++ nmagent/requests.go | 54 ++++++++++++++++++++++++++++++++++++++++ nmagent/requests_test.go | 54 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 4 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 2220b2d477..54c637adf8 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -162,11 +162,36 @@ func (c *Client) PutNetworkContainer(ctx context.Context, nc NetworkContainerReq return nil } -func (c *Client) DeleteNetworkContainer(ctx context.Context) error { - return nil -} +// DeleteNetworkContainer removes a Network Container, its associated IP +// addresses, and network policies from an interface +func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainerRequest) error { + requestStart := time.Now() + + if err := dcr.Validate(); err != nil { + return fmt.Errorf("validating request: %w", err) + } + + path := &url.URL{ + Scheme: "https", + Host: c.hostPort(), + Path: dcr.Path(), + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, path.String(), nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("submitting request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return c.error(time.Since(requestStart), resp.StatusCode) + } -func (c *Client) GetNmAgentSupportedApiURLFmt(ctx context.Context) error { return nil } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 31c0ba816d..16fb2ba330 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -438,3 +438,51 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { }) } } + +func TestNMAgentDeleteNC(t *testing.T) { + deleteTests := []struct { + name string + req nmagent.DeleteContainerRequest + exp string + shouldErr bool + }{ + { + "happy path", + nmagent.DeleteContainerRequest{ + NCID: "00000000-0000-0000-0000-000000000000", + PrimaryAddress: "10.0.0.1", + AuthenticationToken: "swordfish", + }, + "/machine/plugins/?comp=nmagent&type=NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1/method/DELETE", + false, + }, + } + + var got string + for _, test := range deleteTests { + test := test + t.Run(test.name, func(t *testing.T) { + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + rr.Write([]byte(`{"httpStatusCode": "200"}`)) + return rr.Result(), nil + }, + }) + + err := client.DeleteNetworkContainer(context.TODO(), test.req) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if test.exp != got { + t.Errorf("received URL differs from expectation:\n\texp: %q:\n\tgot: %q", test.exp, got) + } + }) + } +} diff --git a/nmagent/requests.go b/nmagent/requests.go index 6882bc9474..534228a430 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -71,3 +71,57 @@ type NetworkContainerRequest struct { // management VNet PrimaryAddress string `json:"-"` } + +// DeleteContainerRequest represents all information necessary to request that +// NMAgent delete a particular network container +type DeleteContainerRequest struct { + NCID string `json:"-"` // the Network Container ID + + // PrimaryAddress is the primary customer address of the interface in the + // management VNET + PrimaryAddress string `json:"-"` + AuthenticationToken string `json:"-"` +} + +// Path returns the path for submitting a DeleteContainerRequest with +// parameters interpolated correctly +func (d DeleteContainerRequest) Path() string { + const DeleteNCPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1/method/DELETE" + return fmt.Sprintf(DeleteNCPath, d.PrimaryAddress, d.NCID, d.AuthenticationToken) +} + +type ValidationError struct { + MissingFields []string +} + +func (v ValidationError) Error() string { + return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", ")) +} + +func (v ValidationError) IsEmpty() bool { + return len(v.MissingFields) == 0 +} + +// Validate ensures that the DeleteContainerRequest has the correct information +// to submit the request +func (d DeleteContainerRequest) Validate() error { + errs := ValidationError{} + + if d.NCID == "" { + errs.MissingFields = append(errs.MissingFields, "NCID") + } + + if d.PrimaryAddress == "" { + errs.MissingFields = append(errs.MissingFields, "PrimaryAddress") + } + + if d.AuthenticationToken == "" { + errs.MissingFields = append(errs.MissingFields, "AuthenticationToken") + } + + if !errs.IsEmpty() { + return errs + } + + return nil +} diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index f7640440ba..5145840ef0 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -50,3 +50,57 @@ func TestPolicyMarshal(t *testing.T) { }) } } + +func TestDeleteContainerRequestValidation(t *testing.T) { + dcrTests := []struct { + name string + req nmagent.DeleteContainerRequest + shouldBeValid bool + }{ + { + "empty", + nmagent.DeleteContainerRequest{}, + false, + }, + { + "missing ncid", + nmagent.DeleteContainerRequest{ + PrimaryAddress: "10.0.0.1", + AuthenticationToken: "swordfish", + }, + false, + }, + { + "missing primary address", + nmagent.DeleteContainerRequest{ + NCID: "00000000-0000-0000-0000-000000000000", + AuthenticationToken: "swordfish", + }, + false, + }, + { + "missing auth token", + nmagent.DeleteContainerRequest{ + NCID: "00000000-0000-0000-0000-000000000000", + PrimaryAddress: "10.0.0.1", + }, + false, + }, + } + + for _, test := range dcrTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("unexpected validation errors: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected request to be invalid but wasn't") + } + }) + } +} From 01198679040cf788ccb960a9cffcd9319fd9e3e6 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 22 Mar 2022 13:09:54 -0400 Subject: [PATCH 12/44] Move internal imports to another section It's a bit clearer when internal imports are isolated into one section, standard library imports in another, then finally external imports in another section. --- nmagent/client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 54c637adf8..efd89053f5 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -3,7 +3,6 @@ package nmagent import ( "bytes" "context" - "dnc/nmagent/internal" "encoding/json" "fmt" "net" @@ -12,7 +11,7 @@ import ( "strconv" "time" - "github.com/google/uuid" + "dnc/nmagent/internal" ) const ( From 22f0fce179a7b9ab9a7f6549929b3e2769bfee9b Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 22 Mar 2022 13:20:06 -0400 Subject: [PATCH 13/44] Additional Validation / Retry improvements This is a bit of a rollup commit, including some additional validation logic for some nmagent requests and also some improvements to the internal retry logic. The retry logic is now a struct, and the client depends only on an interface for retrying. This is to accommodate the existing retry package (which was unknown). The internal Retrier was enhanced to add a configurable Cooldown function with strategies for Fixed backoff, Exponential, and a Max limitation. --- nmagent/client.go | 22 +++--- nmagent/client_helpers_test.go | 3 + nmagent/client_test.go | 6 +- nmagent/internal/internal.go | 54 ++++++++++++++- nmagent/internal/internal_test.go | 18 ++++- nmagent/requests.go | 111 +++++++++++++++++++++--------- nmagent/requests_test.go | 50 ++++++++++++++ 7 files changed, 212 insertions(+), 52 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index efd89053f5..d1d97c0274 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -15,7 +15,6 @@ import ( ) const ( - JoinNetworkPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" ) @@ -31,6 +30,9 @@ func NewClient(host string, port uint16, grace time.Duration) *Client { Host: host, Port: port, UnauthorizedGracePeriod: grace, + Retrier: internal.Retrier{ + Cooldown: internal.Exponential(1*time.Second, 2*time.Second), + }, } } @@ -45,22 +47,24 @@ type Client struct { // UnauthorizedGracePeriod is the amount of time Unauthorized responses from // NMAgent will be tolerated and retried UnauthorizedGracePeriod time.Duration + + Retrier interface { + Do(context.Context, func() error) error + } } // JoinNetwork joins a node to a customer's virtual network -func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { +func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { requestStart := time.Now() - // we need to be a little defensive, because there is no bad request response - // from NMAgent - if _, err := uuid.Parse(networkID); err != nil { - return fmt.Errorf("bad network ID %q: %w", networkID, err) + if err := jnr.Validate(); err != nil { + return fmt.Errorf("validating join network request: %w", err) } joinURL := &url.URL{ Scheme: "https", Host: c.hostPort(), - Path: fmt.Sprintf(JoinNetworkPath, networkID), + Path: jnr.Path(), } req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL.String(), nil) @@ -68,7 +72,7 @@ func (c *Client) JoinNetwork(ctx context.Context, networkID string) error { return fmt.Errorf("creating request: %w", err) } - err = internal.BackoffRetry(ctx, func() error { + err = c.Retrier.Do(ctx, func() error { resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("executing request: %w", err) @@ -102,7 +106,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (Vi return out, fmt.Errorf("creating http request to %q: %w", path.String(), err) } - err = internal.BackoffRetry(ctx, func() error { + err = c.Retrier.Do(ctx, func() error { resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("executing http request to %q: %w", path.String(), err) diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index 1578abad87..c405b82627 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -14,5 +14,8 @@ func NewTestClient(tripper http.RoundTripper) *Client { Transport: tripper, }, }, + Retrier: internal.Retrier{ + Cooldown: internal.AsFastAsPossible, + }, } } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 16fb2ba330..04e47a64cd 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -92,7 +92,7 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { // attempt to join network // TODO(timraymond): need a more realistic network ID, I think - err := client.JoinNetwork(ctx, test.id) + err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{test.id}) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err:", err) } @@ -139,7 +139,7 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { } // attempt to join network - err := client.JoinNetwork(ctx, "00000000-0000-0000-0000-000000000000") + err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{"00000000-0000-0000-0000-000000000000"}) if err != nil { t.Fatal("unexpected error: err:", err) } @@ -185,7 +185,7 @@ func TestNMAgentClientJoinNetworkUnauthorized(t *testing.T) { } // attempt to join network - err := client.JoinNetwork(ctx, "00000000-0000-0000-0000-000000000000") + err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{"00000000-0000-0000-0000-000000000000"}) if err != nil { t.Fatal("unexpected error: err:", err) } diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go index be813672e2..e05589f7e4 100644 --- a/nmagent/internal/internal.go +++ b/nmagent/internal/internal.go @@ -3,6 +3,9 @@ package internal import ( "context" "errors" + "fmt" + "math" + "time" ) type TemporaryError interface { @@ -10,9 +13,14 @@ type TemporaryError interface { Temporary() bool } -// BackoffRetry implements cancellable exponential backoff of some arbitrary -// function -func BackoffRetry(ctx context.Context, run func() error) error { +type Retrier struct { + Cooldown func() error +} + +// Do repeatedly invokes the provided run function while the context remains +// active. It waits in between invocations of the provided functions by +// delegating to the provided Cooldown function +func (r Retrier) Do(ctx context.Context, run func() error) error { for { if err := ctx.Err(); err != nil { return err @@ -23,6 +31,10 @@ func BackoffRetry(ctx context.Context, run func() error) error { // check to see if it's temporary var tempErr TemporaryError if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { + err := r.Cooldown() + if err != nil { + return fmt.Errorf("sleeping during retry: %w", err) + } continue } @@ -32,3 +44,39 @@ func BackoffRetry(ctx context.Context, run func() error) error { return nil } } + +func Max(limit int, f func() error) func() error { + count := 0 + return func() error { + if count >= limit { + return fmt.Errorf("maximum attempts reached (%d)", limit) + } + + err := f() + if err != nil { + return err + } + count++ + return nil + } +} + +func AsFastAsPossible() error { return nil } + +func Exponential(interval time.Duration, base time.Duration) func() error { + count := 0 + return func() error { + increment := math.Pow(float64(base.Nanoseconds()), float64(count)) + delay := interval.Nanoseconds() * int64(increment) + time.Sleep(time.Duration(delay)) + count++ + return nil + } +} + +func Fixed(interval time.Duration) func() error { + return func() error { + time.Sleep(interval) + return nil + } +} diff --git a/nmagent/internal/internal_test.go b/nmagent/internal/internal_test.go index b456ecaf43..71e7701788 100644 --- a/nmagent/internal/internal_test.go +++ b/nmagent/internal/internal_test.go @@ -22,7 +22,11 @@ func TestBackoffRetry(t *testing.T) { ctx := context.Background() - err := BackoffRetry(ctx, func() error { + rt := Retrier{ + Cooldown: AsFastAsPossible, + } + + err := rt.Do(ctx, func() error { if got < exp { got++ return TestError{} @@ -47,7 +51,11 @@ func TestBackoffRetryWithCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err := BackoffRetry(ctx, func() error { + rt := Retrier{ + Cooldown: AsFastAsPossible, + } + + err := rt.Do(ctx, func() error { got++ if got >= exp { cancel() @@ -69,7 +77,11 @@ func TestBackoffRetryWithCancel(t *testing.T) { } func TestBackoffRetryUnretriableError(t *testing.T) { - err := BackoffRetry(context.Background(), func() error { + rt := Retrier{ + Cooldown: AsFastAsPossible, + } + + err := rt.Do(context.Background(), func() error { return errors.New("boom") }) diff --git a/nmagent/requests.go b/nmagent/requests.go index 534228a430..f936dcdf9d 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -6,8 +6,43 @@ import ( "fmt" "strings" "unicode" + + "github.com/google/uuid" ) +// NetworkContainerRequest {{{1 + +type NetworkContainerRequest struct { + ID string `json:"networkContainerID"` // the id of the network container + VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet + Version uint64 `json:"version"` // the new network container version + + // SubnetName is the name of the delegated subnet. This is used to + // authenticate the request. The list of ipv4addresses must be contained in + // the subnet's prefix. + SubnetName string `json:"subnetName"` + + // IPv4 addresses in the customer virtual network that will be assigned to + // the interface. + IPv4Addrs []string `json:"ipV4Addresses"` + + Policies []Policy `json:"policies"` // policies applied to the network container + + // VlanID is used to distinguish Network Containers with duplicate customer + // addresses. "0" is considered a default value by the API. + VlanID int `json:"vlanId"` + + // AuthenticationToken is the base64 security token for the subnet containing + // the Network Container addresses + AuthenticationToken string `json:"-"` + + // PrimaryAddress is the primary customer address of the interface in the + // management VNet + PrimaryAddress string `json:"-"` +} + +// Policy {{{2 + type Policy struct { ID string Type string @@ -43,35 +78,37 @@ func (p *Policy) UnmarshalJSON(in []byte) error { return nil } -type NetworkContainerRequest struct { - ID string `json:"networkContainerID"` // the id of the network container - VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet - Version uint64 `json:"version"` // the new network container version +// }}}2 - // SubnetName is the name of the delegated subnet. This is used to - // authenticate the request. The list of ipv4addresses must be contained in - // the subnet's prefix. - SubnetName string `json:"subnetName"` +// }}}1 - // IPv4 addresses in the customer virtual network that will be assigned to - // the interface. - IPv4Addrs []string `json:"ipV4Addresses"` +// JoinNetworkRequest {{{1 - Policies []Policy `json:"policies"` // policies applied to the network container - - // VlanID is used to distinguish Network Containers with duplicate customer - // addresses. "0" is considered a default value by the API. - VlanID int `json:"vlanId"` +type JoinNetworkRequest struct { + NetworkID string `json:"-"` // the customer's VNet ID +} - // AuthenticationToken is the base64 security token for the subnet containing - // the Network Container addresses - AuthenticationToken string `json:"-"` +// Path constructs a URL path for invoking a JoinNetworkRequest using the +// provided parameters +func (j JoinNetworkRequest) Path() string { + const JoinNetworkPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + return fmt.Sprintf(JoinNetworkPath, j.NetworkID) +} - // PrimaryAddress is the primary customer address of the interface in the - // management VNet - PrimaryAddress string `json:"-"` +// Validate ensures that the provided parameters of the request are valid +func (j JoinNetworkRequest) Validate() error { + // we need to be a little defensive, because there is no bad request response + // from NMAgent + if _, err := uuid.Parse(j.NetworkID); err != nil { + return fmt.Errorf("bad network ID %q: %w", j.NetworkID, err) + } + return nil } +// }}}1 + +// DeleteNetworkRequest {{{1 + // DeleteContainerRequest represents all information necessary to request that // NMAgent delete a particular network container type DeleteContainerRequest struct { @@ -90,18 +127,6 @@ func (d DeleteContainerRequest) Path() string { return fmt.Sprintf(DeleteNCPath, d.PrimaryAddress, d.NCID, d.AuthenticationToken) } -type ValidationError struct { - MissingFields []string -} - -func (v ValidationError) Error() string { - return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", ")) -} - -func (v ValidationError) IsEmpty() bool { - return len(v.MissingFields) == 0 -} - // Validate ensures that the DeleteContainerRequest has the correct information // to submit the request func (d DeleteContainerRequest) Validate() error { @@ -125,3 +150,21 @@ func (d DeleteContainerRequest) Validate() error { return nil } + +// }}}1 + +// ValidationError {{{1 + +type ValidationError struct { + MissingFields []string +} + +func (v ValidationError) Error() string { + return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", ")) +} + +func (v ValidationError) IsEmpty() bool { + return len(v.MissingFields) == 0 +} + +// }}}1 diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index 5145840ef0..a0f2b0c627 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -104,3 +104,53 @@ func TestDeleteContainerRequestValidation(t *testing.T) { }) } } + +func TestJoinNetworkRequestPath(t *testing.T) { + jnr := nmagent.JoinNetworkRequest{ + NetworkID: "00000000-0000-0000-0000-000000000000", + } + + exp := "/NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1" + if jnr.Path() != exp { + t.Error("unexpected path: exp:", exp, "got:", jnr.Path()) + } +} + +func TestJoinNetworkRequestValidate(t *testing.T) { + validateRequest := []struct { + name string + req nmagent.JoinNetworkRequest + shouldBeValid bool + }{ + { + "invalid", + nmagent.JoinNetworkRequest{ + NetworkID: "4815162342", + }, + false, + }, + { + "valid", + nmagent.JoinNetworkRequest{ + NetworkID: "00000000-0000-0000-0000-000000000000", + }, + true, + }, + } + + for _, test := range validateRequest { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("unexpected error validating: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected request to be invalid but wasn't") + } + }) + } +} From 3a2993be46ee419f81f6a2ddbd0c51ec1ae8a7c4 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 22 Mar 2022 23:55:58 -0400 Subject: [PATCH 14/44] Move GetNetworkConfig request params to a struct This follows the pattern established in other API calls. It moves validation to the request itself and also leaves the responsibility for constructing paths to the request. --- nmagent/client.go | 15 ++++++---- nmagent/client_test.go | 6 ++-- nmagent/requests.go | 31 +++++++++++++++++++ nmagent/requests_test.go | 64 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 9 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index d1d97c0274..077820771b 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -15,8 +15,7 @@ import ( ) const ( - GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" - PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" + PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" ) // NewClient returns an initialized Client using the provided configuration @@ -90,17 +89,21 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error // GetNetworkConfiguration retrieves the configuration of a customer's virtual // network. Only subnets which have been delegated will be returned -func (c *Client) GetNetworkConfiguration(ctx context.Context, vnetID string) (VirtualNetwork, error) { +func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkConfigRequest) (VirtualNetwork, error) { requestStart := time.Now() + var out VirtualNetwork + + if err := gncr.Validate(); err != nil { + return out, fmt.Errorf("validating request: %w", err) + } + path := &url.URL{ Scheme: "https", Host: c.hostPort(), - Path: fmt.Sprintf(GetNetworkConfigPath, vnetID), + Path: gncr.Path(), } - var out VirtualNetwork - req, err := http.NewRequestWithContext(ctx, http.MethodGet, path.String(), nil) if err != nil { return out, fmt.Errorf("creating http request to %q: %w", path.String(), err) diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 04e47a64cd..fe6eafe8c5 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -255,7 +255,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { ctx = context.Background() } - gotVNet, err := client.GetNetworkConfiguration(ctx, test.vnetID) + gotVNet, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{test.vnetID}) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err:", err) } @@ -315,7 +315,7 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { ctx = context.Background() } - _, err := client.GetNetworkConfiguration(ctx, "00000000-0000-0000-0000-000000000000") + _, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{"00000000-0000-0000-0000-000000000000"}) if err != nil { t.Fatal("unexpected error: err:", err) } @@ -359,7 +359,7 @@ func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { ctx = context.Background() } - _, err := client.GetNetworkConfiguration(ctx, "00000000-0000-0000-0000-000000000000") + _, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{"00000000-0000-0000-0000-000000000000"}) if err != nil { t.Fatal("unexpected error: err:", err) } diff --git a/nmagent/requests.go b/nmagent/requests.go index f936dcdf9d..e56df9ea2d 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -153,6 +153,37 @@ func (d DeleteContainerRequest) Validate() error { // }}}1 +// GetNetworkConfigRequest {{{1 + +// GetNetworkConfigRequest is a collection of necessary information for +// submitting a request for a customer's network configuration +type GetNetworkConfigRequest struct { + VNetID string `json:"-"` // the customer's virtual network ID +} + +// Path produces a URL path used to submit a request +func (g GetNetworkConfigRequest) Path() string { + const GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + return fmt.Sprintf(GetNetworkConfigPath, g.VNetID) +} + +// Validate ensures that the request is complete and the parameters are correct +func (g GetNetworkConfigRequest) Validate() error { + errs := ValidationError{} + + if g.VNetID == "" { + errs.MissingFields = append(errs.MissingFields, "VNetID") + } + + if !errs.IsEmpty() { + return errs + } + + return nil +} + +// }}}1 + // ValidationError {{{1 type ValidationError struct { diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index a0f2b0c627..565657aa1e 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -154,3 +154,67 @@ func TestJoinNetworkRequestValidate(t *testing.T) { }) } } + +func TestGetNetworkConfigRequestPath(t *testing.T) { + pathTests := []struct { + name string + req nmagent.GetNetworkConfigRequest + exp string + }{ + { + "happy path", + nmagent.GetNetworkConfigRequest{ + VNetID: "00000000-0000-0000-0000-000000000000", + }, + "/NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if got := test.req.Path(); got != test.exp { + t.Error("unexpected path: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestGetNetworkConfigRequestValidate(t *testing.T) { + validateTests := []struct { + name string + req nmagent.GetNetworkConfigRequest + shouldBeValid bool + }{ + { + "happy path", + nmagent.GetNetworkConfigRequest{ + VNetID: "00000000-0000-0000-0000-000000000000", + }, + true, + }, + { + "empty", + nmagent.GetNetworkConfigRequest{}, + false, + }, + } + + for _, test := range validateTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("expected request to be valid but wasn't: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected error to be invalid but wasn't") + } + }) + } +} From f2080ef6dc121727dd39036c4415adb5dceb6475 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 23 Mar 2022 15:13:45 -0400 Subject: [PATCH 15/44] Add Validation and Path to put request To be consistent with the other request types, this adds Validate and Path methods to the PutNetworkContainerRequest --- nmagent/client.go | 14 +-- nmagent/client_test.go | 4 +- nmagent/requests.go | 46 +++++++++- nmagent/requests_test.go | 188 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 241 insertions(+), 11 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 077820771b..154188c637 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -14,10 +14,6 @@ import ( "dnc/nmagent/internal" ) -const ( - PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" -) - // NewClient returns an initialized Client using the provided configuration func NewClient(host string, port uint16, grace time.Duration) *Client { return &Client{ @@ -137,16 +133,20 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon // PutNetworkContainer applies a Network Container goal state and publishes it // to PubSub -func (c *Client) PutNetworkContainer(ctx context.Context, nc NetworkContainerRequest) error { +func (c *Client) PutNetworkContainer(ctx context.Context, pncr PutNetworkContainerRequest) error { requestStart := time.Now() + if err := pncr.Validate(); err != nil { + return fmt.Errorf("validating request: %w", err) + } + path := &url.URL{ Scheme: "https", Host: c.hostPort(), - Path: fmt.Sprintf(PutNCRequestPath, nc.PrimaryAddress, nc.ID, nc.AuthenticationToken), + Path: pncr.Path(), } - body, err := json.Marshal(nc) + body, err := json.Marshal(pncr) if err != nil { return fmt.Errorf("encoding request as JSON: %w", err) } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index fe6eafe8c5..2eb64ba030 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -372,13 +372,13 @@ func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { func TestNMAgentPutNetworkContainer(t *testing.T) { putNCTests := []struct { name string - req nmagent.NetworkContainerRequest + req nmagent.PutNetworkContainerRequest shouldCall bool shouldErr bool }{ { "happy path", - nmagent.NetworkContainerRequest{ + nmagent.PutNetworkContainerRequest{ ID: "350f1e3c-4283-4f51-83a1-c44253962ef1", Version: uint64(12345), VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36", diff --git a/nmagent/requests.go b/nmagent/requests.go index e56df9ea2d..1549adaadc 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -10,9 +10,11 @@ import ( "github.com/google/uuid" ) -// NetworkContainerRequest {{{1 +// PutNetworkContainerRequest {{{1 -type NetworkContainerRequest struct { +// PutNetworkContainerRequest is a collection of parameters necessary to create +// a new network container +type PutNetworkContainerRequest struct { ID string `json:"networkContainerID"` // the id of the network container VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet Version uint64 `json:"version"` // the new network container version @@ -32,6 +34,9 @@ type NetworkContainerRequest struct { // addresses. "0" is considered a default value by the API. VlanID int `json:"vlanId"` + // VirtualNetworkID is the ID of the customer's virtual network + VirtualNetworkID string `json:"virtualNetworkId"` + // AuthenticationToken is the base64 security token for the subnet containing // the Network Container addresses AuthenticationToken string `json:"-"` @@ -41,6 +46,43 @@ type NetworkContainerRequest struct { PrimaryAddress string `json:"-"` } +// Path returns the URL path necessary to submit this PutNetworkContainerRequest +func (p PutNetworkContainerRequest) Path() string { + const PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" + return fmt.Sprintf(PutNCRequestPath, p.PrimaryAddress, p.ID, p.AuthenticationToken) +} + +// Validate ensures that all of the required parameters of the request have +// been filled out properly prior to submission to NMAgent +func (p PutNetworkContainerRequest) Validate() error { + var errs ValidationError + + if len(p.IPv4Addrs) == 0 { + errs.MissingFields = append(errs.MissingFields, "IPv4Addrs") + } + + if p.SubnetName == "" { + errs.MissingFields = append(errs.MissingFields, "SubnetName") + } + + // it's a little unclear as to whether a version value of "0" is actually + // legal. Given that this is the zero value of this field, and the + // documentation of NMAgent requires this to be a uint64, we'll consider "0" + // as unset and require it to be something else. + if p.Version == uint64(0) { + errs.MissingFields = append(errs.MissingFields, "Version") + } + + if p.VirtualNetworkID == "" { + errs.MissingFields = append(errs.MissingFields, "VirtualNetworkID") + } + + if errs.IsEmpty() { + return nil + } + return errs +} + // Policy {{{2 type Policy struct { diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index 565657aa1e..63b78e03f6 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -218,3 +218,191 @@ func TestGetNetworkConfigRequestValidate(t *testing.T) { }) } } + +func TestPutNetworkContainerRequestPath(t *testing.T) { + pathTests := []struct { + name string + req nmagent.PutNetworkContainerRequest + exp string + }{ + { + "happy path", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + "10.0.0.3", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "33333333-3333-3333-3333-333333333333", + }, + "/NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if got := test.req.Path(); got != test.exp { + t.Error("path differs from expectation: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestPutNetworkContainerRequestValidate(t *testing.T) { + validationTests := []struct { + name string + req nmagent.PutNetworkContainerRequest + shouldBeValid bool + }{ + { + "empty", + nmagent.PutNetworkContainerRequest{}, + false, + }, + { + "happy", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + "10.0.0.3", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "33333333-3333-3333-3333-333333333333", + }, + true, + }, + { + "missing IPv4Addrs", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{}, // the important part + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "33333333-3333-3333-3333-333333333333", + }, + false, + }, + { + "missing subnet name", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "", // the important part of the test + IPv4Addrs: []string{ + "10.0.0.2", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "33333333-3333-3333-3333-333333333333", + }, + false, + }, + { + "missing version", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(0), // the important part of the test + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "33333333-3333-3333-3333-333333333333", + }, + false, + }, + { + "missing version", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "", // the important part of the test + }, + false, + }, + } + + for _, test := range validationTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("unexpected error validating: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected validation error but received none") + } + }) + } +} From 6bab759ffd47bf1d3e1ab5938eb2aa4c8a62e464 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 23 Mar 2022 16:22:29 -0400 Subject: [PATCH 16/44] Introduce Request and Option Enough was common among building requests and validating them that it made sense to formalize it as an interface of its own. This allowed centralizing the construction of HTTP requests in the nmagent.Client. As such, it made adding TLS disablement trivial. Since there is some optional behavior that can be configured with the nmagent.Client, nmagent.Option has been introduced to handle this in a clean manner. --- nmagent/client.go | 156 +++++++++++++++++++++-------------------- nmagent/client_test.go | 1 + nmagent/requests.go | 78 +++++++++++++++++++++ 3 files changed, 158 insertions(+), 77 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 154188c637..4cd2ef6a4a 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -1,7 +1,6 @@ package nmagent import ( - "bytes" "context" "encoding/json" "fmt" @@ -14,21 +13,45 @@ import ( "dnc/nmagent/internal" ) +// Option is a functional option for configuration optional behavior in the +// client +type Option func(*Client) + +// InsecureDisableTLS is an option to disable TLS communications with NMAgent +func InsecureDisableTLS() Option { + return func(c *Client) { + c.disableTLS = true + } +} + +// WithUnauthorizedGracePeriod is an option to treat Unauthorized (401) +// responses from NMAgent as temporary errors for a configurable amount of time +func WithUnauthorizedGracePeriod(grace time.Duration) Option { + return func(c *Client) { + c.unauthorizedGracePeriod = grace + } +} + // NewClient returns an initialized Client using the provided configuration -func NewClient(host string, port uint16, grace time.Duration) *Client { - return &Client{ +func NewClient(host string, port uint16, opts ...Option) *Client { + client := &Client{ httpClient: &http.Client{ Transport: &internal.WireserverTransport{ Transport: http.DefaultTransport, }, }, - Host: host, - Port: port, - UnauthorizedGracePeriod: grace, - Retrier: internal.Retrier{ + host: host, + port: port, + retrier: internal.Retrier{ Cooldown: internal.Exponential(1*time.Second, 2*time.Second), }, } + + for _, opt := range opts { + opt(client) + } + + return client } // Client is an agent for exchanging information with NMAgent @@ -36,14 +59,16 @@ type Client struct { httpClient *http.Client // config - Host string - Port uint16 + host string + port uint16 - // UnauthorizedGracePeriod is the amount of time Unauthorized responses from + disableTLS bool + + // unauthorizedGracePeriod is the amount of time Unauthorized responses from // NMAgent will be tolerated and retried - UnauthorizedGracePeriod time.Duration + unauthorizedGracePeriod time.Duration - Retrier interface { + retrier interface { Do(context.Context, func() error) error } } @@ -52,22 +77,12 @@ type Client struct { func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { requestStart := time.Now() - if err := jnr.Validate(); err != nil { - return fmt.Errorf("validating join network request: %w", err) - } - - joinURL := &url.URL{ - Scheme: "https", - Host: c.hostPort(), - Path: jnr.Path(), - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL.String(), nil) + req, err := c.buildRequest(ctx, jnr) if err != nil { - return fmt.Errorf("creating request: %w", err) + return fmt.Errorf("building request: %w", err) } - err = c.Retrier.Do(ctx, func() error { + err = c.retrier.Do(ctx, func() error { resp, err := c.httpClient.Do(req) if err != nil { return fmt.Errorf("executing request: %w", err) @@ -90,25 +105,15 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon var out VirtualNetwork - if err := gncr.Validate(); err != nil { - return out, fmt.Errorf("validating request: %w", err) - } - - path := &url.URL{ - Scheme: "https", - Host: c.hostPort(), - Path: gncr.Path(), - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, path.String(), nil) + req, err := c.buildRequest(ctx, gncr) if err != nil { - return out, fmt.Errorf("creating http request to %q: %w", path.String(), err) + return out, fmt.Errorf("building request: %w", err) } - err = c.Retrier.Do(ctx, func() error { + err = c.retrier.Do(ctx, func() error { resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("executing http request to %q: %w", path.String(), err) + return fmt.Errorf("executing http request to: %w", err) } defer resp.Body.Close() @@ -118,17 +123,13 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon err = json.NewDecoder(resp.Body).Decode(&out) if err != nil { - return fmt.Errorf("decoding json response for %q: %w", path.String(), err) + return fmt.Errorf("decoding json response: %w", err) } return nil }) - if err != nil { - // no need to wrap, as the retry wrapper is intended to be transparent - return out, err - } - return out, nil + return out, err } // PutNetworkContainer applies a Network Container goal state and publishes it @@ -136,24 +137,9 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon func (c *Client) PutNetworkContainer(ctx context.Context, pncr PutNetworkContainerRequest) error { requestStart := time.Now() - if err := pncr.Validate(); err != nil { - return fmt.Errorf("validating request: %w", err) - } - - path := &url.URL{ - Scheme: "https", - Host: c.hostPort(), - Path: pncr.Path(), - } - - body, err := json.Marshal(pncr) - if err != nil { - return fmt.Errorf("encoding request as JSON: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, path.String(), bytes.NewReader(body)) + req, err := c.buildRequest(ctx, pncr) if err != nil { - return fmt.Errorf("creating request: %w", err) + return fmt.Errorf("building request: %w", err) } resp, err := c.httpClient.Do(req) @@ -173,19 +159,9 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr PutNetworkContain func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainerRequest) error { requestStart := time.Now() - if err := dcr.Validate(); err != nil { - return fmt.Errorf("validating request: %w", err) - } - - path := &url.URL{ - Scheme: "https", - Host: c.hostPort(), - Path: dcr.Path(), - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, path.String(), nil) + req, err := c.buildRequest(ctx, dcr) if err != nil { - return fmt.Errorf("creating request: %w", err) + return fmt.Errorf("building request: %w", err) } resp, err := c.httpClient.Do(req) @@ -202,8 +178,34 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainer } func (c *Client) hostPort() string { - port := strconv.Itoa(int(c.Port)) - return net.JoinHostPort(c.Host, port) + port := strconv.Itoa(int(c.port)) + return net.JoinHostPort(c.host, port) +} + +func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, error) { + if err := req.Validate(); err != nil { + return nil, fmt.Errorf("validating request: %w", err) + } + + fullURL := &url.URL{ + Scheme: c.scheme(), + Host: c.hostPort(), + Path: req.Path(), + } + + body, err := req.Body() + if err != nil { + return nil, fmt.Errorf("retrieving request body: %w", err) + } + + return http.NewRequestWithContext(ctx, req.Method(), fullURL.String(), body) +} + +func (c *Client) scheme() string { + if c.disableTLS { + return "http" + } + return "https" } // error constructs a NMAgent error while providing some information configured @@ -211,7 +213,7 @@ func (c *Client) hostPort() string { func (c *Client) error(runtime time.Duration, code int) error { return Error{ Runtime: runtime, - Limit: c.UnauthorizedGracePeriod, + Limit: c.unauthorizedGracePeriod, Code: code, } } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 2eb64ba030..9882a838dc 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -397,6 +397,7 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { VlanID: 1234, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", + VirtualNetworkID: "0000000-0000-0000-0000-000000000000", }, true, false, diff --git a/nmagent/requests.go b/nmagent/requests.go index 1549adaadc..1c0f60dd1f 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -4,14 +4,39 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "net/http" "strings" "unicode" "github.com/google/uuid" ) +// Request {{{1 + +// Request represents an abstracted HTTP request, capable of validating itself, +// producting a valid Path, Body, and its Method +type Request interface { + // Validate should ensure that the request is valid to submit + Validate() error + + // Path should produce a URL path, complete with any URL parameters + // interpolated + Path() string + + // Body produces the HTTP request body necessary to submit the request + Body() (io.Reader, error) + + // Method returns the HTTP Method to be used for the request. + Method() string +} + +// }}}1 + // PutNetworkContainerRequest {{{1 +var _ Request = PutNetworkContainerRequest{} + // PutNetworkContainerRequest is a collection of parameters necessary to create // a new network container type PutNetworkContainerRequest struct { @@ -46,6 +71,22 @@ type PutNetworkContainerRequest struct { PrimaryAddress string `json:"-"` } +// Body marshals the JSON fields of the request and produces an Reader intended +// for use with an HTTP request +func (p PutNetworkContainerRequest) Body() (io.Reader, error) { + body, err := json.Marshal(p) + if err != nil { + return nil, fmt.Errorf("marshaling PutNetworkContainerRequest: %w", err) + } + + return bytes.NewReader(body), nil +} + +// Method returns the HTTP method for this request type +func (p PutNetworkContainerRequest) Method() string { + return http.MethodPost +} + // Path returns the URL path necessary to submit this PutNetworkContainerRequest func (p PutNetworkContainerRequest) Path() string { const PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" @@ -126,6 +167,8 @@ func (p *Policy) UnmarshalJSON(in []byte) error { // JoinNetworkRequest {{{1 +var _ Request = JoinNetworkRequest{} + type JoinNetworkRequest struct { NetworkID string `json:"-"` // the customer's VNet ID } @@ -137,6 +180,16 @@ func (j JoinNetworkRequest) Path() string { return fmt.Sprintf(JoinNetworkPath, j.NetworkID) } +// Body returns nothing, because JoinNetworkRequest has no request body +func (j JoinNetworkRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method returns the HTTP request method to submit a JoinNetworkRequest +func (j JoinNetworkRequest) Method() string { + return http.MethodPost +} + // Validate ensures that the provided parameters of the request are valid func (j JoinNetworkRequest) Validate() error { // we need to be a little defensive, because there is no bad request response @@ -151,6 +204,8 @@ func (j JoinNetworkRequest) Validate() error { // DeleteNetworkRequest {{{1 +var _ Request = DeleteContainerRequest{} + // DeleteContainerRequest represents all information necessary to request that // NMAgent delete a particular network container type DeleteContainerRequest struct { @@ -169,6 +224,16 @@ func (d DeleteContainerRequest) Path() string { return fmt.Sprintf(DeleteNCPath, d.PrimaryAddress, d.NCID, d.AuthenticationToken) } +// Body returns nothing, because DeleteContainerRequests have no HTTP body +func (d DeleteContainerRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method returns the HTTP method required to submit a DeleteContainerRequest +func (d DeleteContainerRequest) Method() string { + return http.MethodPost +} + // Validate ensures that the DeleteContainerRequest has the correct information // to submit the request func (d DeleteContainerRequest) Validate() error { @@ -197,6 +262,8 @@ func (d DeleteContainerRequest) Validate() error { // GetNetworkConfigRequest {{{1 +var _ Request = GetNetworkConfigRequest{} + // GetNetworkConfigRequest is a collection of necessary information for // submitting a request for a customer's network configuration type GetNetworkConfigRequest struct { @@ -209,6 +276,17 @@ func (g GetNetworkConfigRequest) Path() string { return fmt.Sprintf(GetNetworkConfigPath, g.VNetID) } +// Body returns nothing because GetNetworkConfigRequest has no HTTP request +// body +func (g GetNetworkConfigRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method returns the HTTP method required to submit a GetNetworkConfigRequest +func (g GetNetworkConfigRequest) Method() string { + return http.MethodGet +} + // Validate ensures that the request is complete and the parameters are correct func (g GetNetworkConfigRequest) Validate() error { errs := ValidationError{} From e398a6e1d89f27c788f7d526218fbc408f041992 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 23 Mar 2022 16:24:54 -0400 Subject: [PATCH 17/44] Add additional error documentation The NMAgent documentation contains some additional documentation as to the meaning of particular HTTP Status codes. Since we have this information, it makes sense to enhance the nmagent.Error so it can explain what the problem is. --- nmagent/doc.go | 3 +++ nmagent/nmagent.go | 19 ++++++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 nmagent/doc.go diff --git a/nmagent/doc.go b/nmagent/doc.go new file mode 100644 index 0000000000..3d3f7fc597 --- /dev/null +++ b/nmagent/doc.go @@ -0,0 +1,3 @@ +// package nmagent contains types and functions necessary for interacting with +// the Network Manager Agent (NMAgent). +package nmagent diff --git a/nmagent/nmagent.go b/nmagent/nmagent.go index 23a1007228..c68d72d6b3 100644 --- a/nmagent/nmagent.go +++ b/nmagent/nmagent.go @@ -14,8 +14,25 @@ type Error struct { Code int // the HTTP status code received } +// Error constructs a string representation of this error in accordance with +// the error interface func (e Error) Error() string { - return fmt.Sprintf("nmagent: http status %d", e.Code) + return fmt.Sprintf("nmagent: http status %d: %s", e.Code, e.Message()) +} + +// Message interprets the HTTP Status code from NMAgent and returns the +// corresponding explanation from the documentation +func (e Error) Message() string { + switch e.Code { + case http.StatusProcessing: + return "the request is taking time to process. the caller should try the request again" + case http.StatusUnauthorized: + return "the request did not originate from an interface with an OwningServiceInstanceId property" + case http.StatusInternalServerError: + return "error occurred during nmagent's request processing" + default: + return "undocumented nmagent error" + } } // Temporary reports whether the error encountered from NMAgent should be From ff5e8e91e40e198256e0a795ef95280ebdccc121 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 23 Mar 2022 18:45:56 -0400 Subject: [PATCH 18/44] Fix issue with cooldown remembering state Previously, cooldown functions were able to retain state across invocations of the "Do" method of the retrier. This adds an additional layer of functions to allow the Retrier to purge the accumulated state --- nmagent/client.go | 40 ++++++++-------- nmagent/client_helpers_test.go | 24 ++++++---- nmagent/client_test.go | 36 +++++++------- nmagent/internal/internal.go | 78 ++++++++++++++++++++----------- nmagent/internal/internal_test.go | 75 +++++++++++++++++++++++++++-- 5 files changed, 175 insertions(+), 78 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 4cd2ef6a4a..8f9352b2fd 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -13,25 +13,6 @@ import ( "dnc/nmagent/internal" ) -// Option is a functional option for configuration optional behavior in the -// client -type Option func(*Client) - -// InsecureDisableTLS is an option to disable TLS communications with NMAgent -func InsecureDisableTLS() Option { - return func(c *Client) { - c.disableTLS = true - } -} - -// WithUnauthorizedGracePeriod is an option to treat Unauthorized (401) -// responses from NMAgent as temporary errors for a configurable amount of time -func WithUnauthorizedGracePeriod(grace time.Duration) Option { - return func(c *Client) { - c.unauthorizedGracePeriod = grace - } -} - // NewClient returns an initialized Client using the provided configuration func NewClient(host string, port uint16, opts ...Option) *Client { client := &Client{ @@ -43,7 +24,7 @@ func NewClient(host string, port uint16, opts ...Option) *Client { host: host, port: port, retrier: internal.Retrier{ - Cooldown: internal.Exponential(1*time.Second, 2*time.Second), + Cooldown: internal.Exponential(1*time.Second, 2), }, } @@ -73,6 +54,25 @@ type Client struct { } } +// Option is a functional option for configuration optional behavior in the +// client +type Option func(*Client) + +// InsecureDisableTLS is an option to disable TLS communications with NMAgent +func InsecureDisableTLS() Option { + return func(c *Client) { + c.disableTLS = true + } +} + +// WithUnauthorizedGracePeriod is an option to treat Unauthorized (401) +// responses from NMAgent as temporary errors for a configurable amount of time +func WithUnauthorizedGracePeriod(grace time.Duration) Option { + return func(c *Client) { + c.unauthorizedGracePeriod = grace + } +} + // JoinNetwork joins a node to a customer's virtual network func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { requestStart := time.Now() diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index c405b82627..74e0105c67 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -5,17 +5,23 @@ import ( "net/http" ) -// NewTestClient creates an NMAgent Client suitable for use in tests. This is -// unavailable in production builds -func NewTestClient(tripper http.RoundTripper) *Client { - return &Client{ - httpClient: &http.Client{ +// WithTransport allows a test to specify a particular http.RoundTripper for +// use in testing scenarios +func WithTransport(tripper http.RoundTripper) Option { + return func(c *Client) { + c.httpClient = &http.Client{ Transport: &internal.WireserverTransport{ Transport: tripper, }, - }, - Retrier: internal.Retrier{ - Cooldown: internal.AsFastAsPossible, - }, + } + } +} + +// NoBackoff disables exponential backoff in the client +func NoBackoff() Option { + return func(c *Client) { + c.retrier = internal.Retrier{ + Cooldown: internal.AsFastAsPossible(), + } } } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 9882a838dc..1736b2fd25 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -70,7 +70,7 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { // create a client var got string - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() @@ -78,7 +78,7 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, - }) + }), nmagent.NoBackoff()) // if the test provides a timeout, use it in the context var ctx context.Context @@ -114,7 +114,7 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { invocations := 0 exp := 10 - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if invocations < exp { @@ -126,7 +126,7 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }) + }), nmagent.NoBackoff()) // if the test provides a timeout, use it in the context var ctx context.Context @@ -158,7 +158,7 @@ func TestNMAgentClientJoinNetworkUnauthorized(t *testing.T) { invocations := 0 exp := 10 - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if invocations < exp { @@ -170,9 +170,7 @@ func TestNMAgentClientJoinNetworkUnauthorized(t *testing.T) { rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }) - - client.UnauthorizedGracePeriod = 1 * time.Minute + }), nmagent.WithUnauthorizedGracePeriod(1*time.Minute), nmagent.NoBackoff()) // if the test provides a timeout, use it in the context var ctx context.Context @@ -231,7 +229,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { t.Parallel() var got string - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() got = req.URL.Path @@ -243,7 +241,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { return rr.Result(), nil }, - }) + }), nmagent.NoBackoff()) // if the test provides a timeout, use it in the context var ctx context.Context @@ -289,7 +287,7 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { count := 0 exp := 10 - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if count < exp { @@ -303,7 +301,7 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }) + }), nmagent.NoBackoff()) // if the test provides a timeout, use it in the context var ctx context.Context @@ -330,7 +328,7 @@ func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { count := 0 exp := 10 - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if count < exp { @@ -345,9 +343,7 @@ func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { return rr.Result(), nil }, - }) - - client.UnauthorizedGracePeriod = 1 * time.Minute + }), nmagent.WithUnauthorizedGracePeriod(1*time.Minute), nmagent.NoBackoff()) // if the test provides a timeout, use it in the context var ctx context.Context @@ -410,7 +406,7 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { t.Parallel() didCall := false - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() rr.Write([]byte(`{"httpStatusCode": "200"}`)) @@ -418,7 +414,7 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { didCall = true return rr.Result(), nil }, - }) + }), nmagent.NoBackoff()) err := client.PutNetworkContainer(context.TODO(), test.req) if err != nil && !test.shouldErr { @@ -463,14 +459,14 @@ func TestNMAgentDeleteNC(t *testing.T) { for _, test := range deleteTests { test := test t.Run(test.name, func(t *testing.T) { - client := nmagent.NewTestClient(&TestTripper{ + client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }) + }), nmagent.NoBackoff()) err := client.DeleteNetworkContainer(context.TODO(), test.req) if err != nil && !test.shouldErr { diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go index e05589f7e4..2e57c404e5 100644 --- a/nmagent/internal/internal.go +++ b/nmagent/internal/internal.go @@ -14,13 +14,15 @@ type TemporaryError interface { } type Retrier struct { - Cooldown func() error + Cooldown CooldownFactory } // Do repeatedly invokes the provided run function while the context remains // active. It waits in between invocations of the provided functions by // delegating to the provided Cooldown function func (r Retrier) Do(ctx context.Context, run func() error) error { + cooldown := r.Cooldown() + for { if err := ctx.Err(); err != nil { return err @@ -31,7 +33,7 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { // check to see if it's temporary var tempErr TemporaryError if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { - err := r.Cooldown() + err := cooldown() if err != nil { return fmt.Errorf("sleeping during retry: %w", err) } @@ -45,38 +47,62 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { } } -func Max(limit int, f func() error) func() error { - count := 0 - return func() error { - if count >= limit { - return fmt.Errorf("maximum attempts reached (%d)", limit) - } +// CooldownFunc is a function that will block when called. It is intended for +// use with retry logic. +type CooldownFunc func() error - err := f() - if err != nil { - return err +// CooldownFactory is a function that returns CooldownFuncs. It helps +// CooldownFuncs dispose of any accumulated state so that they function +// correctly upon successive uses. +type CooldownFactory func() CooldownFunc + +func Max(limit int, factory CooldownFactory) CooldownFactory { + return func() CooldownFunc { + cooldown := factory() + count := 0 + return func() error { + if count >= limit { + return fmt.Errorf("maximum attempts reached (%d)", limit) + } + + err := cooldown() + if err != nil { + return err + } + count++ + return nil } - count++ - return nil } } -func AsFastAsPossible() error { return nil } +// AsFastAsPossible is a Cooldown strategy that does not block, allowing retry +// logic to proceed as fast as possible. This is particularly useful in tests +func AsFastAsPossible() CooldownFactory { + return func() CooldownFunc { + return func() error { + return nil + } + } +} -func Exponential(interval time.Duration, base time.Duration) func() error { - count := 0 - return func() error { - increment := math.Pow(float64(base.Nanoseconds()), float64(count)) - delay := interval.Nanoseconds() * int64(increment) - time.Sleep(time.Duration(delay)) - count++ - return nil +func Exponential(interval time.Duration, base int) CooldownFactory { + return func() CooldownFunc { + count := 0 + return func() error { + increment := math.Pow(float64(base), float64(count)) + delay := interval.Nanoseconds() * int64(increment) + time.Sleep(time.Duration(delay)) + count++ + return nil + } } } -func Fixed(interval time.Duration) func() error { - return func() error { - time.Sleep(interval) - return nil +func Fixed(interval time.Duration) CooldownFactory { + return func() CooldownFunc { + return func() error { + time.Sleep(interval) + return nil + } } } diff --git a/nmagent/internal/internal_test.go b/nmagent/internal/internal_test.go index 71e7701788..cb62214cf6 100644 --- a/nmagent/internal/internal_test.go +++ b/nmagent/internal/internal_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" ) type TestError struct{} @@ -23,7 +24,7 @@ func TestBackoffRetry(t *testing.T) { ctx := context.Background() rt := Retrier{ - Cooldown: AsFastAsPossible, + Cooldown: AsFastAsPossible(), } err := rt.Do(ctx, func() error { @@ -52,7 +53,7 @@ func TestBackoffRetryWithCancel(t *testing.T) { defer cancel() rt := Retrier{ - Cooldown: AsFastAsPossible, + Cooldown: AsFastAsPossible(), } err := rt.Do(ctx, func() error { @@ -78,7 +79,7 @@ func TestBackoffRetryWithCancel(t *testing.T) { func TestBackoffRetryUnretriableError(t *testing.T) { rt := Retrier{ - Cooldown: AsFastAsPossible, + Cooldown: AsFastAsPossible(), } err := rt.Do(context.Background(), func() error { @@ -89,3 +90,71 @@ func TestBackoffRetryUnretriableError(t *testing.T) { t.Fatal("expected an error, but none was returned") } } + +func TestFixed(t *testing.T) { + exp := 20 * time.Millisecond + + cooldown := Fixed(exp)() + start := time.Now() + + cooldown() + + if got := time.Since(start); got < exp { + t.Fatal("unexpected sleep duration: exp:", exp, "got:", got) + } +} + +func TestExp(t *testing.T) { + interval := 10 * time.Millisecond + base := 2 + + cooldown := Exponential(interval, base)() + + start := time.Now() + cooldown() + + first := time.Since(start) + if first < interval { + t.Fatal("unexpected sleep during first cooldown: exp:", interval, "got:", first) + } + + // ensure that the sleep increases + cooldown() + + second := time.Since(start) + if second < first { + t.Fatal("unexpected sleep during first cooldown: exp:", interval, "got:", second) + } +} + +func TestMax(t *testing.T) { + exp := 10 + got := 0 + + // create a test sleep function + fn := func() CooldownFunc { + return func() error { + got++ + return nil + } + } + + cooldown := Max(10, fn)() + + for i := 0; i < exp; i++ { + err := cooldown() + if err != nil { + t.Fatal("unexpected error from cooldown: err:", err) + } + } + + if exp != got { + t.Error("unexpected number of cooldown invocations: exp:", exp, "got:", got) + } + + // attempt one more, we expect an error + err := cooldown() + if err == nil { + t.Errorf("expected an error after %d invocations but received none", exp+1) + } +} From be0e60cf89cf4318ea18acd0f2fa78609528d298 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 24 Mar 2022 12:19:38 -0400 Subject: [PATCH 19/44] Move Validation to reflection-based helper The validation logic for each struct was repetitive and it didn't help answer the common question "what fields are required by this request struct." This adds a "validate" struct tag that can be used to annotate fields within the request struct and mark them as required. It's still possible to do arbitrary validation within the Validate method of each request, but the common things like "is this field a zero value?" are abstracted into the internal helper. This also serves as documentation to future readers, making it easier to use the package. --- nmagent/client_test.go | 7 -- nmagent/internal/validate.go | 55 +++++++++++++++ nmagent/internal/validate_test.go | 92 +++++++++++++++++++++++++ nmagent/requests.go | 109 +++++------------------------- nmagent/requests_test.go | 2 +- 5 files changed, 166 insertions(+), 99 deletions(-) create mode 100644 nmagent/internal/validate.go create mode 100644 nmagent/internal/validate_test.go diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 1736b2fd25..d24ab78aad 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -47,13 +47,6 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { http.StatusOK, // this shouldn't be checked true, }, - { - "malformed UUID", - "00000000-0000", - "", - http.StatusOK, // this shouldn't be checked - true, - }, { "internal error", "00000000-0000-0000-0000-000000000000", diff --git a/nmagent/internal/validate.go b/nmagent/internal/validate.go new file mode 100644 index 0000000000..ee28b69b06 --- /dev/null +++ b/nmagent/internal/validate.go @@ -0,0 +1,55 @@ +package internal + +import ( + "fmt" + "reflect" + "strings" +) + +// ValidationError {{{1 + +type ValidationError struct { + MissingFields []string +} + +func (v ValidationError) Error() string { + return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", ")) +} + +func (v ValidationError) IsEmpty() bool { + return len(v.MissingFields) == 0 +} + +// }}}1 + +// Validate searches for validate struct tags and performs the validations +// requested by them +func Validate(obj interface{}) error { + errs := ValidationError{} + + val := reflect.ValueOf(obj) + typ := reflect.TypeOf(obj) + + for i := 0; i < val.NumField(); i++ { + fieldVal := val.Field(i) + fieldTyp := typ.Field(i) + + op := fieldTyp.Tag.Get("validate") + switch op { + case "presence": + if fieldVal.Kind() == reflect.Slice { + if fieldVal.Len() == 0 { + errs.MissingFields = append(errs.MissingFields, fieldTyp.Name) + } + } else if fieldVal.IsZero() { + errs.MissingFields = append(errs.MissingFields, fieldTyp.Name) + } + } + } + + if errs.IsEmpty() { + return nil + } + + return errs +} diff --git a/nmagent/internal/validate_test.go b/nmagent/internal/validate_test.go new file mode 100644 index 0000000000..884774d94b --- /dev/null +++ b/nmagent/internal/validate_test.go @@ -0,0 +1,92 @@ +package internal + +import "testing" + +func TestValidate(t *testing.T) { + validateTests := []struct { + name string + sub interface{} + shouldBeValid bool + shouldPanic bool + }{ + { + "empty", + struct{}{}, + true, + false, + }, + { + "no tags", + struct { + Foo string + }{""}, + true, + false, + }, + { + "presence", + struct { + Foo string `validate:"presence"` + }{"hi"}, + true, + false, + }, + { + "presence empty", + struct { + Foo string `validate:"presence"` + }{}, + false, + false, + }, + { + "required empty slice", + struct { + Foo []string `validate:"presence"` + }{}, + false, + false, + }, + { + "not a struct", + 42, + false, + true, + }, + { + "slice", + []interface{}{}, + false, + true, + }, + { + "map", + map[string]interface{}{}, + false, + true, + }, + } + + for _, test := range validateTests { + test := test + t.Run(test.name, func(t *testing.T) { + defer func() { + if err := recover(); err != nil && !test.shouldPanic { + t.Fatal("unexpected panic received: err:", err) + } else if err == nil && test.shouldPanic { + t.Fatal("expected panic but received none") + } + }() + t.Parallel() + + err := Validate(test.sub) + if err != nil && test.shouldBeValid { + t.Fatal("unexpected error validating: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected subject to be invalid but wasn't") + } + }) + } +} diff --git a/nmagent/requests.go b/nmagent/requests.go index 1c0f60dd1f..15869ad8a5 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -2,14 +2,13 @@ package nmagent import ( "bytes" + "dnc/nmagent/internal" "encoding/json" "fmt" "io" "net/http" "strings" "unicode" - - "github.com/google/uuid" ) // Request {{{1 @@ -40,18 +39,20 @@ var _ Request = PutNetworkContainerRequest{} // PutNetworkContainerRequest is a collection of parameters necessary to create // a new network container type PutNetworkContainerRequest struct { - ID string `json:"networkContainerID"` // the id of the network container - VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet - Version uint64 `json:"version"` // the new network container version + ID string `json:"networkContainerID"` // the id of the network container + VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet + + // Version is the new network container version + Version uint64 `validate:"presence" json:"version"` // SubnetName is the name of the delegated subnet. This is used to // authenticate the request. The list of ipv4addresses must be contained in // the subnet's prefix. - SubnetName string `json:"subnetName"` + SubnetName string `validate:"presence" json:"subnetName"` // IPv4 addresses in the customer virtual network that will be assigned to // the interface. - IPv4Addrs []string `json:"ipV4Addresses"` + IPv4Addrs []string `validate:"presence" json:"ipV4Addresses"` Policies []Policy `json:"policies"` // policies applied to the network container @@ -60,7 +61,7 @@ type PutNetworkContainerRequest struct { VlanID int `json:"vlanId"` // VirtualNetworkID is the ID of the customer's virtual network - VirtualNetworkID string `json:"virtualNetworkId"` + VirtualNetworkID string `validate:"presence" json:"virtualNetworkId"` // AuthenticationToken is the base64 security token for the subnet containing // the Network Container addresses @@ -96,32 +97,7 @@ func (p PutNetworkContainerRequest) Path() string { // Validate ensures that all of the required parameters of the request have // been filled out properly prior to submission to NMAgent func (p PutNetworkContainerRequest) Validate() error { - var errs ValidationError - - if len(p.IPv4Addrs) == 0 { - errs.MissingFields = append(errs.MissingFields, "IPv4Addrs") - } - - if p.SubnetName == "" { - errs.MissingFields = append(errs.MissingFields, "SubnetName") - } - - // it's a little unclear as to whether a version value of "0" is actually - // legal. Given that this is the zero value of this field, and the - // documentation of NMAgent requires this to be a uint64, we'll consider "0" - // as unset and require it to be something else. - if p.Version == uint64(0) { - errs.MissingFields = append(errs.MissingFields, "Version") - } - - if p.VirtualNetworkID == "" { - errs.MissingFields = append(errs.MissingFields, "VirtualNetworkID") - } - - if errs.IsEmpty() { - return nil - } - return errs + return internal.Validate(p) } // Policy {{{2 @@ -170,7 +146,7 @@ func (p *Policy) UnmarshalJSON(in []byte) error { var _ Request = JoinNetworkRequest{} type JoinNetworkRequest struct { - NetworkID string `json:"-"` // the customer's VNet ID + NetworkID string `validate:"presence" json:"-"` // the customer's VNet ID } // Path constructs a URL path for invoking a JoinNetworkRequest using the @@ -192,12 +168,7 @@ func (j JoinNetworkRequest) Method() string { // Validate ensures that the provided parameters of the request are valid func (j JoinNetworkRequest) Validate() error { - // we need to be a little defensive, because there is no bad request response - // from NMAgent - if _, err := uuid.Parse(j.NetworkID); err != nil { - return fmt.Errorf("bad network ID %q: %w", j.NetworkID, err) - } - return nil + return internal.Validate(j) } // }}}1 @@ -209,12 +180,12 @@ var _ Request = DeleteContainerRequest{} // DeleteContainerRequest represents all information necessary to request that // NMAgent delete a particular network container type DeleteContainerRequest struct { - NCID string `json:"-"` // the Network Container ID + NCID string `validate:"presence" json:"-"` // the Network Container ID // PrimaryAddress is the primary customer address of the interface in the // management VNET - PrimaryAddress string `json:"-"` - AuthenticationToken string `json:"-"` + PrimaryAddress string `validate:"presence" json:"-"` + AuthenticationToken string `validate:"presence" json:"-"` } // Path returns the path for submitting a DeleteContainerRequest with @@ -237,25 +208,7 @@ func (d DeleteContainerRequest) Method() string { // Validate ensures that the DeleteContainerRequest has the correct information // to submit the request func (d DeleteContainerRequest) Validate() error { - errs := ValidationError{} - - if d.NCID == "" { - errs.MissingFields = append(errs.MissingFields, "NCID") - } - - if d.PrimaryAddress == "" { - errs.MissingFields = append(errs.MissingFields, "PrimaryAddress") - } - - if d.AuthenticationToken == "" { - errs.MissingFields = append(errs.MissingFields, "AuthenticationToken") - } - - if !errs.IsEmpty() { - return errs - } - - return nil + return internal.Validate(d) } // }}}1 @@ -267,7 +220,7 @@ var _ Request = GetNetworkConfigRequest{} // GetNetworkConfigRequest is a collection of necessary information for // submitting a request for a customer's network configuration type GetNetworkConfigRequest struct { - VNetID string `json:"-"` // the customer's virtual network ID + VNetID string `validate:"presence" json:"-"` // the customer's virtual network ID } // Path produces a URL path used to submit a request @@ -289,33 +242,7 @@ func (g GetNetworkConfigRequest) Method() string { // Validate ensures that the request is complete and the parameters are correct func (g GetNetworkConfigRequest) Validate() error { - errs := ValidationError{} - - if g.VNetID == "" { - errs.MissingFields = append(errs.MissingFields, "VNetID") - } - - if !errs.IsEmpty() { - return errs - } - - return nil -} - -// }}}1 - -// ValidationError {{{1 - -type ValidationError struct { - MissingFields []string -} - -func (v ValidationError) Error() string { - return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", ")) -} - -func (v ValidationError) IsEmpty() bool { - return len(v.MissingFields) == 0 + return internal.Validate(g) } // }}}1 diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index 63b78e03f6..3691c5033b 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -125,7 +125,7 @@ func TestJoinNetworkRequestValidate(t *testing.T) { { "invalid", nmagent.JoinNetworkRequest{ - NetworkID: "4815162342", + NetworkID: "", }, false, }, From 7eb129d43f4f6e94c02f11cce1cb94dadaec9d53 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 24 Mar 2022 12:26:10 -0400 Subject: [PATCH 20/44] Housekeeping: file renaming nmagent.go was really focused on the nmagent.Error type, so it made sense to rename the file to be more revealing. The same goes for internal.go and internal_test.go. Both of those were focused on retry logic. Also added a quick note explaining why client_helpers_test.go exists, since it can be a little subtle to those new to the language. --- nmagent/client_helpers_test.go | 4 ++++ nmagent/{nmagent.go => error.go} | 0 nmagent/internal/{internal.go => retry.go} | 0 nmagent/internal/{internal_test.go => retry_test.go} | 0 4 files changed, 4 insertions(+) rename nmagent/{nmagent.go => error.go} (100%) rename nmagent/internal/{internal.go => retry.go} (100%) rename nmagent/internal/{internal_test.go => retry_test.go} (100%) diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index 74e0105c67..2e80fbc911 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -5,6 +5,10 @@ import ( "net/http" ) +// Note: this file exists to add two additional options with access to the +// internals of Client. It's a *_test file so that it is only compiled during +// tests. + // WithTransport allows a test to specify a particular http.RoundTripper for // use in testing scenarios func WithTransport(tripper http.RoundTripper) Option { diff --git a/nmagent/nmagent.go b/nmagent/error.go similarity index 100% rename from nmagent/nmagent.go rename to nmagent/error.go diff --git a/nmagent/internal/internal.go b/nmagent/internal/retry.go similarity index 100% rename from nmagent/internal/internal.go rename to nmagent/internal/retry.go diff --git a/nmagent/internal/internal_test.go b/nmagent/internal/retry_test.go similarity index 100% rename from nmagent/internal/internal_test.go rename to nmagent/internal/retry_test.go From a0b0fb046fb5d96b108c3cd8aebd6612201dd88e Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 29 Mar 2022 17:49:24 -0400 Subject: [PATCH 21/44] Remove Vim fold markers While this is nice for vim users, @ramiro-gamarra rightly pointed out that this is a maintenance burden for non-vim users with little benefit. Removing these to reduce the overhead. --- nmagent/requests.go | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/nmagent/requests.go b/nmagent/requests.go index 15869ad8a5..2d880fbde9 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -11,8 +11,6 @@ import ( "unicode" ) -// Request {{{1 - // Request represents an abstracted HTTP request, capable of validating itself, // producting a valid Path, Body, and its Method type Request interface { @@ -30,10 +28,6 @@ type Request interface { Method() string } -// }}}1 - -// PutNetworkContainerRequest {{{1 - var _ Request = PutNetworkContainerRequest{} // PutNetworkContainerRequest is a collection of parameters necessary to create @@ -100,8 +94,6 @@ func (p PutNetworkContainerRequest) Validate() error { return internal.Validate(p) } -// Policy {{{2 - type Policy struct { ID string Type string @@ -137,12 +129,6 @@ func (p *Policy) UnmarshalJSON(in []byte) error { return nil } -// }}}2 - -// }}}1 - -// JoinNetworkRequest {{{1 - var _ Request = JoinNetworkRequest{} type JoinNetworkRequest struct { @@ -171,10 +157,6 @@ func (j JoinNetworkRequest) Validate() error { return internal.Validate(j) } -// }}}1 - -// DeleteNetworkRequest {{{1 - var _ Request = DeleteContainerRequest{} // DeleteContainerRequest represents all information necessary to request that @@ -211,10 +193,6 @@ func (d DeleteContainerRequest) Validate() error { return internal.Validate(d) } -// }}}1 - -// GetNetworkConfigRequest {{{1 - var _ Request = GetNetworkConfigRequest{} // GetNetworkConfigRequest is a collection of necessary information for @@ -244,5 +222,3 @@ func (g GetNetworkConfigRequest) Method() string { func (g GetNetworkConfigRequest) Validate() error { return internal.Validate(g) } - -// }}}1 From 50a007a083be6f85167c6e886826d51cbb7c2fd5 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 29 Mar 2022 17:54:59 -0400 Subject: [PATCH 22/44] Set default scheme to http for nmagent client In practice, most communication for the nmagent client occurs over HTTP because it is intra-node traffic. While this is a useful option to have, the default should be useful for the common use case. --- nmagent/client.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 8f9352b2fd..61c3f5a59f 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -43,7 +43,7 @@ type Client struct { host string port uint16 - disableTLS bool + enableTLS bool // unauthorizedGracePeriod is the amount of time Unauthorized responses from // NMAgent will be tolerated and retried @@ -58,10 +58,11 @@ type Client struct { // client type Option func(*Client) -// InsecureDisableTLS is an option to disable TLS communications with NMAgent -func InsecureDisableTLS() Option { +// EnableTLS is an option to force all connections to NMAgent to occur over +// TLS. +func EnableTLS() Option { return func(c *Client) { - c.disableTLS = true + c.enableTLS = true } } @@ -202,10 +203,10 @@ func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, } func (c *Client) scheme() string { - if c.disableTLS { - return "http" + if c.enableTLS { + return "https" } - return "https" + return "http" } // error constructs a NMAgent error while providing some information configured From f0739936599f818e1c7760eca1ec758675bfab03 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 29 Mar 2022 18:29:20 -0400 Subject: [PATCH 23/44] Change retry functions to return durations It was somewhat limiting that cooldown functions themselves would block. What was really interesting about them is that they calculated some time.Duration. Since passing 0 to time.Sleep causes it to return immediately, this has no impact on the AsFastAsPossible strategy Also improved some documentation and added a few examples at the request of @aegal --- nmagent/internal/retry.go | 43 ++++++++++------ nmagent/internal/retry_example_test.go | 68 ++++++++++++++++++++++++++ nmagent/internal/retry_test.go | 39 ++++++++------- 3 files changed, 117 insertions(+), 33 deletions(-) create mode 100644 nmagent/internal/retry_example_test.go diff --git a/nmagent/internal/retry.go b/nmagent/internal/retry.go index 2e57c404e5..c14d16bb59 100644 --- a/nmagent/internal/retry.go +++ b/nmagent/internal/retry.go @@ -8,11 +8,19 @@ import ( "time" ) +const ( + noDelay = 0 * time.Nanosecond +) + +// TemporaryError is an error that can indicate whether it may be resolved with +// another attempt type TemporaryError interface { error Temporary() bool } +// Retrier is a construct for attempting some operation multiple times with a +// configurable backoff strategy type Retrier struct { Cooldown CooldownFactory } @@ -33,10 +41,11 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { // check to see if it's temporary var tempErr TemporaryError if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { - err := cooldown() + delay, err := cooldown() if err != nil { return fmt.Errorf("sleeping during retry: %w", err) } + time.Sleep(delay) continue } @@ -49,28 +58,30 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { // CooldownFunc is a function that will block when called. It is intended for // use with retry logic. -type CooldownFunc func() error +type CooldownFunc func() (time.Duration, error) // CooldownFactory is a function that returns CooldownFuncs. It helps // CooldownFuncs dispose of any accumulated state so that they function // correctly upon successive uses. type CooldownFactory func() CooldownFunc +// Max provides a fixed limit for the number of times a subordinate cooldown +// function can be invoked func Max(limit int, factory CooldownFactory) CooldownFactory { return func() CooldownFunc { cooldown := factory() count := 0 - return func() error { + return func() (time.Duration, error) { if count >= limit { - return fmt.Errorf("maximum attempts reached (%d)", limit) + return noDelay, fmt.Errorf("maximum attempts reached (%d)", limit) } - err := cooldown() + delay, err := cooldown() if err != nil { - return err + return noDelay, err } count++ - return nil + return delay, nil } } } @@ -79,30 +90,30 @@ func Max(limit int, factory CooldownFactory) CooldownFactory { // logic to proceed as fast as possible. This is particularly useful in tests func AsFastAsPossible() CooldownFactory { return func() CooldownFunc { - return func() error { - return nil + return func() (time.Duration, error) { + return noDelay, nil } } } +// Exponential provides an exponential increase the the base interval provided func Exponential(interval time.Duration, base int) CooldownFactory { return func() CooldownFunc { count := 0 - return func() error { + return func() (time.Duration, error) { increment := math.Pow(float64(base), float64(count)) delay := interval.Nanoseconds() * int64(increment) - time.Sleep(time.Duration(delay)) count++ - return nil + return time.Duration(delay), nil } } } -func Fixed(interval time.Duration) CooldownFactory { +// Fixed produced the same delay value upon each invocation +func Fixed(delay time.Duration) CooldownFactory { return func() CooldownFunc { - return func() error { - time.Sleep(interval) - return nil + return func() (time.Duration, error) { + return delay, nil } } } diff --git a/nmagent/internal/retry_example_test.go b/nmagent/internal/retry_example_test.go new file mode 100644 index 0000000000..1d3f43cc1b --- /dev/null +++ b/nmagent/internal/retry_example_test.go @@ -0,0 +1,68 @@ +package internal + +import ( + "fmt" + "time" +) + +func ExampleExponential() { + // this example details the common case where the powers of 2 are desired + cooldown := Exponential(1*time.Millisecond, 2)() + + for i := 0; i < 5; i++ { + got, err := cooldown() + if err != nil { + fmt.Println("received error during cooldown: err:", err) + return + } + + fmt.Println(got) + } + + // Output: + // 1ms + // 2ms + // 4ms + // 8ms + // 16ms +} + +func ExampleFixed() { + cooldown := Fixed(10 * time.Millisecond)() + + for i := 0; i < 5; i++ { + got, err := cooldown() + if err != nil { + fmt.Println("unexpected error cooling down: err", err) + return + } + fmt.Println(got) + + // Output: + // 10ms + // 10ms + // 10ms + // 10ms + // 10ms + } +} + +func ExampleMax() { + cooldown := Max(4, Fixed(10*time.Millisecond))() + + for i := 0; i < 5; i++ { + got, err := cooldown() + if err != nil { + fmt.Println("error cooling down:", err) + break + } + fmt.Println(got) + + // Output: + // 10ms + // 10ms + // 10ms + // 10ms + // error cooling down: maximum attempts reached (4) + } +} diff --git a/nmagent/internal/retry_test.go b/nmagent/internal/retry_test.go index cb62214cf6..70fe00481f 100644 --- a/nmagent/internal/retry_test.go +++ b/nmagent/internal/retry_test.go @@ -95,35 +95,40 @@ func TestFixed(t *testing.T) { exp := 20 * time.Millisecond cooldown := Fixed(exp)() - start := time.Now() - cooldown() + got, err := cooldown() + if err != nil { + t.Fatal("unexpected error invoking cooldown: err:", err) + } - if got := time.Since(start); got < exp { + if got != exp { t.Fatal("unexpected sleep duration: exp:", exp, "got:", got) } } func TestExp(t *testing.T) { - interval := 10 * time.Millisecond + exp := 10 * time.Millisecond base := 2 - cooldown := Exponential(interval, base)() + cooldown := Exponential(exp, base)() - start := time.Now() - cooldown() + first, err := cooldown() + if err != nil { + t.Fatal("unexpected error invoking cooldown: err:", err) + } - first := time.Since(start) - if first < interval { - t.Fatal("unexpected sleep during first cooldown: exp:", interval, "got:", first) + if first != exp { + t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", first) } // ensure that the sleep increases - cooldown() + second, err := cooldown() + if err != nil { + t.Fatal("unexpected error on second invocation of cooldown: err:", err) + } - second := time.Since(start) if second < first { - t.Fatal("unexpected sleep during first cooldown: exp:", interval, "got:", second) + t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", second) } } @@ -133,16 +138,16 @@ func TestMax(t *testing.T) { // create a test sleep function fn := func() CooldownFunc { - return func() error { + return func() (time.Duration, error) { got++ - return nil + return 0 * time.Nanosecond, nil } } cooldown := Max(10, fn)() for i := 0; i < exp; i++ { - err := cooldown() + _, err := cooldown() if err != nil { t.Fatal("unexpected error from cooldown: err:", err) } @@ -153,7 +158,7 @@ func TestMax(t *testing.T) { } // attempt one more, we expect an error - err := cooldown() + _, err := cooldown() if err == nil { t.Errorf("expected an error after %d invocations but received none", exp+1) } From aeedeb9f4e3a9bdc64faf149c67e8e1df9ba94a8 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 29 Mar 2022 20:04:05 -0400 Subject: [PATCH 24/44] Rename imports The imports were incorrect because this client was moved from another module. --- nmagent/client.go | 2 +- nmagent/client_helpers_test.go | 3 ++- nmagent/client_test.go | 3 ++- nmagent/nmagent_test.go | 3 ++- nmagent/requests.go | 3 ++- nmagent/requests_test.go | 3 ++- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 61c3f5a59f..20fef71fb4 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -10,7 +10,7 @@ import ( "strconv" "time" - "dnc/nmagent/internal" + "github.com/Azure/azure-container-networking/nmagent/internal" ) // NewClient returns an initialized Client using the provided configuration diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index 2e80fbc911..d07cfb659e 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -1,8 +1,9 @@ package nmagent import ( - "dnc/nmagent/internal" "net/http" + + "github.com/Azure/azure-container-networking/nmagent/internal" ) // Note: this file exists to add two additional options with access to the diff --git a/nmagent/client_test.go b/nmagent/client_test.go index d24ab78aad..1452320dcf 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -2,7 +2,6 @@ package nmagent_test import ( "context" - "dnc/nmagent" "encoding/json" "fmt" "net/http" @@ -10,6 +9,8 @@ import ( "testing" "time" + "github.com/Azure/azure-container-networking/nmagent" + "github.com/google/go-cmp/cmp" ) diff --git a/nmagent/nmagent_test.go b/nmagent/nmagent_test.go index f77cb40f01..abf2b19971 100644 --- a/nmagent/nmagent_test.go +++ b/nmagent/nmagent_test.go @@ -1,10 +1,11 @@ package nmagent_test import ( - "dnc/nmagent" "net/http" "testing" "time" + + "github.com/Azure/azure-container-networking/nmagent" ) func TestErrorTemp(t *testing.T) { diff --git a/nmagent/requests.go b/nmagent/requests.go index 2d880fbde9..2cb47467ae 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -2,13 +2,14 @@ package nmagent import ( "bytes" - "dnc/nmagent/internal" "encoding/json" "fmt" "io" "net/http" "strings" "unicode" + + "github.com/Azure/azure-container-networking/nmagent/internal" ) // Request represents an abstracted HTTP request, capable of validating itself, diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index 3691c5033b..28dc119930 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -1,10 +1,11 @@ package nmagent_test import ( - "dnc/nmagent" "encoding/json" "testing" + "github.com/Azure/azure-container-networking/nmagent" + "github.com/google/go-cmp/cmp" ) From f942f3b7c15bb25cd820a1952f675fb86d653fcc Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 31 Mar 2022 16:19:05 -0400 Subject: [PATCH 25/44] Duplicate the request in wireserver transport Upon closer reading of the RoundTripper documentation, it's clear that RoundTrippers should not modify the request. While effort has been made to reset any mutations, this is still, technically, modifying the request. Instead, this duplicates the request immediately and re-uses the context that was provided to it. --- nmagent/internal/wireserver.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index d3d6a96a4a..2554ed85c3 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -44,9 +44,15 @@ type WireserverTransport struct { // RoundTrip executes arbitrary HTTP requests against Wireserver while applying // the necessary transformation rules to make such requests acceptable to // Wireserver -func (w *WireserverTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, error) { + // RoundTrippers are not allowed to modify the request, so we clone it here. + // We need to extract the context from the request first since this is _not_ + // cloned. The dependent Wireserver request should have the same deadline and + // cancellation properties as the inbound request though, hence the reuse. + ctx := inReq.Context() + req := inReq.Clone(ctx) + // the original path of the request must be prefixed with wireserver's path - origPath := req.URL.Path path := WirePrefix if req.URL.Path != "" { path += req.URL.Path[1:] @@ -63,18 +69,10 @@ func (w *WireserverTransport) RoundTrip(req *http.Request) (*http.Response, erro } req.URL.Path = path - // ensure that nothing has changed from the caller's perspective by resetting - // the URL - defer func() { - req.URL.Path = origPath - }() // wireserver cannot tolerate PUT requests, so it's necessary to transform those to POSTs if req.Method == http.MethodPut { req.Method = http.MethodPost - defer func() { - req.Method = http.MethodPut - }() } // all POST requests (and by extension, PUT) must have a non-nil body From 005ca05118b72a7fe0ecae274b45e0e45e693f7f Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 31 Mar 2022 16:21:41 -0400 Subject: [PATCH 26/44] Drain and close http ResponseBodies It's not entirely clear whether this is needed or not. The documentation for http.(*Client).Do indicates that this is necessary, but experimentation in the community has found that this is maybe not 100% necessary (calling `Close` on the Body appears to be enough). The only harm that can come from this is if Wireserver hands back enormous responses, which is not the case--these responses are fairly small. --- nmagent/internal/wireserver.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 2554ed85c3..66fc3b0231 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -85,8 +85,12 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er if err != nil { return resp, err } - // we want to close this because we're going to replace it - defer resp.Body.Close() + // we need to take the body as an argument to this drain & close so we can + // bind to this specific instance because we intend to replace it + defer func(body io.ReadCloser) { + io.Copy(io.Discard, body) + body.Close() + }(resp.Body) if resp.StatusCode != http.StatusOK { return resp, nil From 244fa4dbdeb753172dde741bc53fe79fd1ad02e1 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 6 Apr 2022 10:48:30 -0400 Subject: [PATCH 27/44] Capture unexpected content from Wireserver During certain error cases, Wireserver may return XML. This XML is useful in debugging, so we want to capture it in the error and surface it appropriately. It's unclear whether Wireserver notes the Content-Type, so we use Go's content type detection to figure out what the type of the response is and clean it up to pass along to the NMAgent Client. This also introduces a new ContentError which semantically represents the situation where we were given a content type that we didn't expect. --- nmagent/client.go | 5 +++ nmagent/error.go | 40 ++++++++++++++++++ nmagent/internal/internal.go | 10 +++++ nmagent/internal/wireserver.go | 70 +++++++++++++++++++++++++------ nmagent/nmagent_test.go | 76 ++++++++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+), 12 deletions(-) create mode 100644 nmagent/internal/internal.go diff --git a/nmagent/client.go b/nmagent/client.go index 20fef71fb4..61b8336680 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -122,6 +122,11 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon return c.error(time.Since(requestStart), resp.StatusCode) } + ct := resp.Header.Get(internal.HeaderContentType) + if ct != internal.MimeJSON { + return NewContentError(ct, resp.Body, resp.ContentLength) + } + err = json.NewDecoder(resp.Body).Decode(&out) if err != nil { return fmt.Errorf("decoding json response: %w", err) diff --git a/nmagent/error.go b/nmagent/error.go index c68d72d6b3..5b9d8a678f 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -1,11 +1,51 @@ package nmagent import ( + "errors" "fmt" + "io" "net/http" "time" + + "github.com/Azure/azure-container-networking/nmagent/internal" ) +// ContentError is encountered when an unexpected content type is obtained from +// NMAgent +type ContentError struct { + Type string // the mime type of the content received + Body []byte // the received body +} + +func (c ContentError) Error() string { + if c.Type == internal.MimeOctetStream { + return fmt.Sprintf("unexpected content type %q: body length: %d", c.Type, len(c.Body)) + } + return fmt.Sprintf("unexpected content type %q: body: %s", c.Type, c.Body) +} + +// NewContentError creates a ContentError from a provided reader and limit. +func NewContentError(contentType string, in io.Reader, limit int64) error { + out := ContentError{ + Type: contentType, + Body: make([]byte, limit), + } + + bodyReader := io.LimitReader(in, limit) + + read, err := io.ReadFull(bodyReader, out.Body) + earlyEOF := errors.Is(err, io.ErrUnexpectedEOF) + if err != nil && !earlyEOF { + return fmt.Errorf("reading unexpected content body: %w", err) + } + + if earlyEOF { + out.Body = out.Body[:read] + } + + return out +} + // Error is a aberrent condition encountered when interacting with the NMAgent // API type Error struct { diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go new file mode 100644 index 0000000000..88ec948672 --- /dev/null +++ b/nmagent/internal/internal.go @@ -0,0 +1,10 @@ +package internal + +const ( + HeaderContentType = "Content-Type" +) + +const ( + MimeJSON = "application/json" + MimeOctetStream = "application/octet-stream" +) diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 66fc3b0231..40cc925f3e 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -3,6 +3,7 @@ package internal import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -10,7 +11,20 @@ import ( "strings" ) -const WirePrefix string = "/machine/plugins/?comp=nmagent&type=" +const ( + _ int64 = 1 << (10 * iota) + kilobyte + megabyte +) + +const ( + WirePrefix string = "/machine/plugins/?comp=nmagent&type=" + + // DefaultBufferSize is the maximum number of bytes read from Wireserver in + // the event that no Content-Length is provided. The responses are relatively + // small, so the smallest page size should be sufficient + DefaultBufferSize int64 = 4 * kilobyte +) var _ http.RoundTripper = &WireserverTransport{} @@ -70,7 +84,8 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er req.URL.Path = path - // wireserver cannot tolerate PUT requests, so it's necessary to transform those to POSTs + // wireserver cannot tolerate PUT requests, so it's necessary to transform + // those to POSTs if req.Method == http.MethodPut { req.Method = http.MethodPost } @@ -85,24 +100,55 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er if err != nil { return resp, err } - // we need to take the body as an argument to this drain & close so we can - // bind to this specific instance because we intend to replace it + + if resp.StatusCode != http.StatusOK { + // something happened at Wireserver, so we should just hand this back up + return resp, nil + } + + // at this point we're definitely going to modify the body, so we want to + // make sure we close the original request's body, since we're going to + // replace it defer func(body io.ReadCloser) { - io.Copy(io.Discard, body) body.Close() }(resp.Body) - if resp.StatusCode != http.StatusOK { - return resp, nil + // buffer the entire response from Wireserver + clen := resp.ContentLength + if clen < 0 { + clen = DefaultBufferSize + } + + body := make([]byte, clen) + bodyReader := io.LimitReader(resp.Body, clen) + + numRead, err := io.ReadFull(bodyReader, body) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { + return nil, fmt.Errorf("reading response from wireserver: %w", err) } + // it's entirely possible at this point that we read less than we allocated, + // so trim the slice back for decoding + body = body[:numRead] + + // set the content length properly in case it wasn't set. If it was, this is + // effectively a no-op + resp.ContentLength = int64(numRead) - // correct the HTTP status returned from wireserver + // it's unclear whether Wireserver sets Content-Type appropriately, so we + // attempt to decode it first and surface it otherwise var wsResp WireserverResponse - err = json.NewDecoder(resp.Body).Decode(&wsResp) + err = json.Unmarshal(body, &wsResp) if err != nil { - return resp, fmt.Errorf("decoding json response from wireserver: %w", err) + // probably not JSON, so figure out what it is, pack it up, and surface it + // unmodified + resp.Header.Set(HeaderContentType, http.DetectContentType(body)) + resp.Body = io.NopCloser(bytes.NewReader(body)) + return resp, nil } + // we know that it's JSON now, so communicate that upwards + resp.Header.Set(HeaderContentType, MimeJSON) + // set the response status code with the *real* status code realCode, err := wsResp.StatusCode() if err != nil { @@ -114,12 +160,12 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er // re-encode the body and re-attach it to the response delete(wsResp, "httpStatusCode") // TODO(timraymond): concern of the response - body, err := json.Marshal(wsResp) + outBody, err := json.Marshal(wsResp) if err != nil { return resp, fmt.Errorf("re-encoding json response from wireserver: %w", err) } - resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.Body = io.NopCloser(bytes.NewReader(outBody)) return resp, nil } diff --git a/nmagent/nmagent_test.go b/nmagent/nmagent_test.go index abf2b19971..79068a3e41 100644 --- a/nmagent/nmagent_test.go +++ b/nmagent/nmagent_test.go @@ -1,7 +1,11 @@ package nmagent_test import ( + "bytes" + "errors" + "io" "net/http" + "strings" "testing" "time" @@ -78,3 +82,75 @@ func TestErrorTemp(t *testing.T) { }) } } + +func TestContentErrorNew(t *testing.T) { + errTests := []struct { + name string + body io.Reader + limit int64 + contentType string + exp string + shouldMakeErr bool + }{ + { + "empty", + strings.NewReader(""), + 0, + "text/plain", + "unexpected content type \"text/plain\": body: ", + true, + }, + { + "happy path", + strings.NewReader("random text"), + 11, + "text/plain", + "unexpected content type \"text/plain\": body: random text", + true, + }, + { + // if the body is an octet stream, it's entirely possible that it's + // unprintable garbage. This ensures that we just print the length + "octets", + bytes.NewReader([]byte{0xde, 0xad, 0xbe, 0xef}), + 4, + "application/octet-stream", + "unexpected content type \"application/octet-stream\": body length: 4", + true, + }, + { + // even if the length is wrong, we still want to return as much data as + // we can for debugging + "wrong len", + bytes.NewReader([]byte{0xde, 0xad, 0xbe, 0xef}), + 8, + "application/octet-stream", + "unexpected content type \"application/octet-stream\": body length: 4", + true, + }, + } + + for _, test := range errTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := nmagent.NewContentError(test.contentType, test.body, test.limit) + + var e nmagent.ContentError + wasContentErr := errors.As(err, &e) + if !wasContentErr && test.shouldMakeErr { + t.Fatalf("error was not a ContentError") + } + + if wasContentErr && !test.shouldMakeErr { + t.Fatalf("received a ContentError when it was not expected") + } + + got := err.Error() + if got != test.exp { + t.Error("unexpected error message: got:", got, "exp:", test.exp) + } + }) + } +} From 20fb63a5babd3fcd019e2aedcc45b10063299443 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 6 Apr 2022 10:50:56 -0400 Subject: [PATCH 28/44] Don't return a response with an error in RoundTrip The http.Client complains if you return a non-nil response and an error as well. This fixes one instance where that was happening. --- nmagent/internal/wireserver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 40cc925f3e..1d1de6aa9b 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -98,7 +98,7 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er // execute the request to the downstream transport resp, err := w.Transport.RoundTrip(req) if err != nil { - return resp, err + return nil, err } if resp.StatusCode != http.StatusOK { From 9c7191cd3f633dab9f8b43dde4b5e25cac7d1097 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Wed, 6 Apr 2022 10:53:55 -0400 Subject: [PATCH 29/44] Remove extra vim folding marks These were intended to be removed in another commit, but there were some stragglers. --- nmagent/internal/validate.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nmagent/internal/validate.go b/nmagent/internal/validate.go index ee28b69b06..71ab90bdd4 100644 --- a/nmagent/internal/validate.go +++ b/nmagent/internal/validate.go @@ -6,8 +6,6 @@ import ( "strings" ) -// ValidationError {{{1 - type ValidationError struct { MissingFields []string } @@ -20,8 +18,6 @@ func (v ValidationError) IsEmpty() bool { return len(v.MissingFields) == 0 } -// }}}1 - // Validate searches for validate struct tags and performs the validations // requested by them func Validate(obj interface{}) error { From 3910ae0338009a9441af0bdc4c37ea3710684feb Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 11 Apr 2022 14:29:35 -0400 Subject: [PATCH 30/44] Replace fmt.Errorf with errors.Wrap Even though fmt.Errorf provides an official error-wrapping solution for Go, we have made the decision to use errors.Wrap for its stack collection support. This integrates well with Uber's Zap logger, which we also plan to integrate. --- nmagent/client.go | 25 +++++++++++++------------ nmagent/client_test.go | 7 ++++--- nmagent/error.go | 4 +++- nmagent/internal/retry.go | 4 +++- nmagent/internal/wireserver.go | 12 +++++++----- nmagent/internal/wireserver_test.go | 4 ++-- nmagent/requests.go | 6 ++++-- 7 files changed, 36 insertions(+), 26 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 61b8336680..e8bc885f21 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -3,13 +3,14 @@ package nmagent import ( "context" "encoding/json" - "fmt" "net" "net/http" "net/url" "strconv" "time" + "github.com/pkg/errors" + "github.com/Azure/azure-container-networking/nmagent/internal" ) @@ -80,13 +81,13 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error req, err := c.buildRequest(ctx, jnr) if err != nil { - return fmt.Errorf("building request: %w", err) + return errors.Wrap(err, "building request") } err = c.retrier.Do(ctx, func() error { resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("executing request: %w", err) + return errors.Wrap(err, "executing request") } defer resp.Body.Close() @@ -108,13 +109,13 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon req, err := c.buildRequest(ctx, gncr) if err != nil { - return out, fmt.Errorf("building request: %w", err) + return out, errors.Wrap(err, "building request") } err = c.retrier.Do(ctx, func() error { resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("executing http request to: %w", err) + return errors.Wrap(err, "executing http request to") } defer resp.Body.Close() @@ -129,7 +130,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon err = json.NewDecoder(resp.Body).Decode(&out) if err != nil { - return fmt.Errorf("decoding json response: %w", err) + return errors.Wrap(err, "decoding json response") } return nil @@ -145,12 +146,12 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr PutNetworkContain req, err := c.buildRequest(ctx, pncr) if err != nil { - return fmt.Errorf("building request: %w", err) + return errors.Wrap(err, "building request") } resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("submitting request: %w", err) + return errors.Wrap(err, "submitting request") } defer resp.Body.Close() @@ -167,12 +168,12 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainer req, err := c.buildRequest(ctx, dcr) if err != nil { - return fmt.Errorf("building request: %w", err) + return errors.Wrap(err, "building request") } resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("submitting request: %w", err) + return errors.Wrap(err, "submitting request") } defer resp.Body.Close() @@ -190,7 +191,7 @@ func (c *Client) hostPort() string { func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, error) { if err := req.Validate(); err != nil { - return nil, fmt.Errorf("validating request: %w", err) + return nil, errors.Wrap(err, "validating request") } fullURL := &url.URL{ @@ -201,7 +202,7 @@ func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, body, err := req.Body() if err != nil { - return nil, fmt.Errorf("retrieving request body: %w", err) + return nil, errors.Wrap(err, "retrieving request body") } return http.NewRequestWithContext(ctx, req.Method(), fullURL.String(), body) diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 1452320dcf..b5642916df 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -9,9 +9,10 @@ import ( "testing" "time" - "github.com/Azure/azure-container-networking/nmagent" - "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" + + "github.com/Azure/azure-container-networking/nmagent" ) var _ http.RoundTripper = &TestTripper{} @@ -230,7 +231,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { rr.WriteHeader(http.StatusOK) err := json.NewEncoder(rr).Encode(&test.expVNet) if err != nil { - return nil, fmt.Errorf("encoding response: %w", err) + return nil, errors.Wrap(err, "encoding response") } return rr.Result(), nil diff --git a/nmagent/error.go b/nmagent/error.go index 5b9d8a678f..31d5595bb3 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -7,6 +7,8 @@ import ( "net/http" "time" + pkgerrors "github.com/pkg/errors" + "github.com/Azure/azure-container-networking/nmagent/internal" ) @@ -36,7 +38,7 @@ func NewContentError(contentType string, in io.Reader, limit int64) error { read, err := io.ReadFull(bodyReader, out.Body) earlyEOF := errors.Is(err, io.ErrUnexpectedEOF) if err != nil && !earlyEOF { - return fmt.Errorf("reading unexpected content body: %w", err) + return pkgerrors.Wrap(err, "reading unexpected content body") } if earlyEOF { diff --git a/nmagent/internal/retry.go b/nmagent/internal/retry.go index c14d16bb59..339f8c9de5 100644 --- a/nmagent/internal/retry.go +++ b/nmagent/internal/retry.go @@ -6,6 +6,8 @@ import ( "fmt" "math" "time" + + pkgerrors "github.com/pkg/errors" ) const ( @@ -43,7 +45,7 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { delay, err := cooldown() if err != nil { - return fmt.Errorf("sleeping during retry: %w", err) + return pkgerrors.Wrap(err, "sleeping during retry") } time.Sleep(delay) continue diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 1d1de6aa9b..85c6fa3273 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -9,6 +9,8 @@ import ( "net/http" "strconv" "strings" + + pkgerrors "github.com/pkg/errors" ) const ( @@ -37,11 +39,11 @@ func (w WireserverResponse) StatusCode() (int, error) { var statusStr string err := json.Unmarshal(status, &statusStr) if err != nil { - return 0, fmt.Errorf("unmarshaling httpStatusCode from Wireserver: %w", err) + return 0, pkgerrors.Wrap(err, "unmarshaling httpStatusCode from Wireserver") } if code, err := strconv.Atoi(statusStr); err != nil { - return code, fmt.Errorf("parsing http status code from wireserver: %w", err) + return code, pkgerrors.Wrap(err, "parsing http status code from wireserver") } else { return code, nil } @@ -124,7 +126,7 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er numRead, err := io.ReadFull(bodyReader, body) if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { - return nil, fmt.Errorf("reading response from wireserver: %w", err) + return nil, pkgerrors.Wrap(err, "reading response from wireserver") } // it's entirely possible at this point that we read less than we allocated, // so trim the slice back for decoding @@ -152,7 +154,7 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er // set the response status code with the *real* status code realCode, err := wsResp.StatusCode() if err != nil { - return resp, fmt.Errorf("retrieving status code from wireserver response: %w", err) + return resp, pkgerrors.Wrap(err, "retrieving status code from wireserver response") } resp.StatusCode = realCode @@ -162,7 +164,7 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er outBody, err := json.Marshal(wsResp) if err != nil { - return resp, fmt.Errorf("re-encoding json response from wireserver: %w", err) + return resp, pkgerrors.Wrap(err, "re-encoding json response from wireserver") } resp.Body = io.NopCloser(bytes.NewReader(outBody)) diff --git a/nmagent/internal/wireserver_test.go b/nmagent/internal/wireserver_test.go index f78abc5204..94594ce99d 100644 --- a/nmagent/internal/wireserver_test.go +++ b/nmagent/internal/wireserver_test.go @@ -2,12 +2,12 @@ package internal import ( "encoding/json" - "fmt" "net/http" "net/http/httptest" "testing" "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" ) type TestTripper struct { @@ -136,7 +136,7 @@ func TestWireserverTransportStatusTransform(t *testing.T) { err := json.NewEncoder(rr).Encode(&test.response) if err != nil { - return nil, fmt.Errorf("encoding json response: %w", err) + return nil, errors.Wrap(err, "encoding json response") } return rr.Result(), nil diff --git a/nmagent/requests.go b/nmagent/requests.go index 2cb47467ae..3734ee1a52 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -9,6 +9,8 @@ import ( "strings" "unicode" + "github.com/pkg/errors" + "github.com/Azure/azure-container-networking/nmagent/internal" ) @@ -72,7 +74,7 @@ type PutNetworkContainerRequest struct { func (p PutNetworkContainerRequest) Body() (io.Reader, error) { body, err := json.Marshal(p) if err != nil { - return nil, fmt.Errorf("marshaling PutNetworkContainerRequest: %w", err) + return nil, errors.Wrap(err, "marshaling PutNetworkContainerRequest") } return bytes.NewReader(body), nil @@ -116,7 +118,7 @@ func (p *Policy) UnmarshalJSON(in []byte) error { var raw string err := json.Unmarshal(in, &raw) if err != nil { - return fmt.Errorf("decoding policy: %w", err) + return errors.Wrap(err, "decoding policy") } parts := strings.Split(raw, ",") From 6f79a164965fa4eb482c4b66f0002d3853ecf070 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 11 Apr 2022 15:19:31 -0400 Subject: [PATCH 31/44] Use Config struct instead of functional Options We determined that a Config struct would be more obvious than the functional options in a debugging scenario. --- nmagent/client.go | 25 +++----- nmagent/client_helpers_test.go | 30 ++++----- nmagent/client_test.go | 111 ++++----------------------------- nmagent/config.go | 36 +++++++++++ nmagent/config_test.go | 59 ++++++++++++++++++ 5 files changed, 127 insertions(+), 134 deletions(-) create mode 100644 nmagent/config.go create mode 100644 nmagent/config_test.go diff --git a/nmagent/client.go b/nmagent/client.go index e8bc885f21..77b5a45e7e 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -15,25 +15,26 @@ import ( ) // NewClient returns an initialized Client using the provided configuration -func NewClient(host string, port uint16, opts ...Option) *Client { +func NewClient(c Config) (*Client, error) { + if err := c.Validate(); err != nil { + return nil, errors.Wrap(err, "validating config") + } + client := &Client{ httpClient: &http.Client{ Transport: &internal.WireserverTransport{ Transport: http.DefaultTransport, }, }, - host: host, - port: port, + host: c.Host, + port: c.Port, + enableTLS: c.UseTLS, retrier: internal.Retrier{ Cooldown: internal.Exponential(1*time.Second, 2), }, } - for _, opt := range opts { - opt(client) - } - - return client + return client, nil } // Client is an agent for exchanging information with NMAgent @@ -59,14 +60,6 @@ type Client struct { // client type Option func(*Client) -// EnableTLS is an option to force all connections to NMAgent to occur over -// TLS. -func EnableTLS() Option { - return func(c *Client) { - c.enableTLS = true - } -} - // WithUnauthorizedGracePeriod is an option to treat Unauthorized (401) // responses from NMAgent as temporary errors for a configurable amount of time func WithUnauthorizedGracePeriod(grace time.Duration) Option { diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index d07cfb659e..4f98959742 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -6,27 +6,19 @@ import ( "github.com/Azure/azure-container-networking/nmagent/internal" ) -// Note: this file exists to add two additional options with access to the -// internals of Client. It's a *_test file so that it is only compiled during -// tests. - -// WithTransport allows a test to specify a particular http.RoundTripper for -// use in testing scenarios -func WithTransport(tripper http.RoundTripper) Option { - return func(c *Client) { - c.httpClient = &http.Client{ +// NewTestClient is a factory function available in tests only for creating +// NMAgent clients with a mock transport +func NewTestClient(transport http.RoundTripper) *Client { + return &Client{ + httpClient: &http.Client{ Transport: &internal.WireserverTransport{ - Transport: tripper, + Transport: transport, }, - } - } -} - -// NoBackoff disables exponential backoff in the client -func NoBackoff() Option { - return func(c *Client) { - c.retrier = internal.Retrier{ + }, + host: "localhost", + port: 12345, + retrier: internal.Retrier{ Cooldown: internal.AsFastAsPossible(), - } + }, } } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index b5642916df..55c977ce58 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" @@ -65,7 +64,7 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { // create a client var got string - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ + client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() @@ -73,7 +72,7 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, - }), nmagent.NoBackoff()) + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -109,7 +108,7 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { invocations := 0 exp := 10 - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ + client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if invocations < exp { @@ -121,51 +120,7 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }), nmagent.NoBackoff()) - - // if the test provides a timeout, use it in the context - var ctx context.Context - if deadline, ok := t.Deadline(); ok { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), deadline) - defer cancel() - } else { - ctx = context.Background() - } - - // attempt to join network - err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{"00000000-0000-0000-0000-000000000000"}) - if err != nil { - t.Fatal("unexpected error: err:", err) - } - - if invocations != exp { - t.Error("client did not make the expected number of API calls: got:", invocations, "exp:", exp) - } -} - -// TODO(timraymond): this is super repetitive (see the retry test) -func TestNMAgentClientJoinNetworkUnauthorized(t *testing.T) { - t.Parallel() - - // we want to ensure that the client will automatically follow up with - // NMAgent, so we want to track the number of requests that it makes - invocations := 0 - exp := 10 - - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - if invocations < exp { - rr.WriteHeader(http.StatusUnauthorized) - invocations++ - } else { - rr.WriteHeader(http.StatusOK) - } - rr.Write([]byte(`{"httpStatusCode": "200"}`)) - return rr.Result(), nil - }, - }), nmagent.WithUnauthorizedGracePeriod(1*time.Minute), nmagent.NoBackoff()) + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -224,7 +179,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { t.Parallel() var got string - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ + client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() got = req.URL.Path @@ -236,7 +191,7 @@ func TestNMAgentGetNetworkConfig(t *testing.T) { return rr.Result(), nil }, - }), nmagent.NoBackoff()) + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -282,7 +237,7 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { count := 0 exp := 10 - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ + client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() if count < exp { @@ -296,49 +251,7 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }), nmagent.NoBackoff()) - - // if the test provides a timeout, use it in the context - var ctx context.Context - if deadline, ok := t.Deadline(); ok { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(context.Background(), deadline) - defer cancel() - } else { - ctx = context.Background() - } - - _, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{"00000000-0000-0000-0000-000000000000"}) - if err != nil { - t.Fatal("unexpected error: err:", err) - } - - if count != exp { - t.Error("unexpected number of API calls: exp:", exp, "got:", count) - } -} - -func TestNMAgentGetNetworkConfigUnauthorized(t *testing.T) { - t.Parallel() - - count := 0 - exp := 10 - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ - RoundTripF: func(req *http.Request) (*http.Response, error) { - rr := httptest.NewRecorder() - if count < exp { - rr.WriteHeader(http.StatusUnauthorized) - count++ - } else { - rr.WriteHeader(http.StatusOK) - } - - // we still need a fake response - rr.Write([]byte(`{"httpStatusCode": "200"}`)) - - return rr.Result(), nil - }, - }), nmagent.WithUnauthorizedGracePeriod(1*time.Minute), nmagent.NoBackoff()) + }) // if the test provides a timeout, use it in the context var ctx context.Context @@ -401,7 +314,7 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { t.Parallel() didCall := false - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ + client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() rr.Write([]byte(`{"httpStatusCode": "200"}`)) @@ -409,7 +322,7 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { didCall = true return rr.Result(), nil }, - }), nmagent.NoBackoff()) + }) err := client.PutNetworkContainer(context.TODO(), test.req) if err != nil && !test.shouldErr { @@ -454,14 +367,14 @@ func TestNMAgentDeleteNC(t *testing.T) { for _, test := range deleteTests { test := test t.Run(test.name, func(t *testing.T) { - client := nmagent.NewClient("localhost", 8080, nmagent.WithTransport(&TestTripper{ + client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() rr.Write([]byte(`{"httpStatusCode": "200"}`)) return rr.Result(), nil }, - }), nmagent.NoBackoff()) + }) err := client.DeleteNetworkContainer(context.TODO(), test.req) if err != nil && !test.shouldErr { diff --git a/nmagent/config.go b/nmagent/config.go new file mode 100644 index 0000000000..5514b8eb81 --- /dev/null +++ b/nmagent/config.go @@ -0,0 +1,36 @@ +package nmagent + +import "github.com/Azure/azure-container-networking/nmagent/internal" + +// Config is a configuration for an NMAgent Client +type Config struct { + ///////////////////// + // Required Config // + ///////////////////// + Host string // the host the client will connect to + Port uint16 // the port the client will connect to + + ///////////////////// + // Optional Config // + ///////////////////// + UseTLS bool // forces all connections to use TLS +} + +// Validate reports whether this configuration is a valid configuration for a +// client +func (c Config) Validate() error { + err := internal.ValidationError{} + + if c.Host == "" { + err.MissingFields = append(err.MissingFields, "Host") + } + + if c.Port == 0 { + err.MissingFields = append(err.MissingFields, "Port") + } + + if err.IsEmpty() { + return nil + } + return err +} diff --git a/nmagent/config_test.go b/nmagent/config_test.go new file mode 100644 index 0000000000..ed9b65f4f9 --- /dev/null +++ b/nmagent/config_test.go @@ -0,0 +1,59 @@ +package nmagent_test + +import ( + "testing" + + "github.com/Azure/azure-container-networking/nmagent" +) + +func TestConfig(t *testing.T) { + configTests := []struct { + name string + config nmagent.Config + expValid bool + }{ + { + "empty", + nmagent.Config{}, + false, + }, + { + "missing port", + nmagent.Config{ + Host: "localhost", + }, + false, + }, + { + "missing host", + nmagent.Config{ + Port: 12345, + }, + false, + }, + { + "complete", + nmagent.Config{ + Host: "localhost", + Port: 12345, + }, + true, + }, + } + + for _, test := range configTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.config.Validate() + if err != nil && test.expValid { + t.Fatal("expected config to be valid but wasnt: err:", err) + } + + if err == nil && !test.expValid { + t.Fatal("expected config to be invalid but wasn't") + } + }) + } +} From 5cb39d6a58011a86b413f6dc936dceaf1019be1f Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 11 Apr 2022 15:40:20 -0400 Subject: [PATCH 32/44] Remove validation struct tags The validation struct tags were deemed too magical and thus removed in favor of straight-line validation logic. --- nmagent/internal/validate.go | 33 ----------- nmagent/internal/validate_test.go | 92 ------------------------------- nmagent/requests.go | 78 ++++++++++++++++++++++---- 3 files changed, 67 insertions(+), 136 deletions(-) delete mode 100644 nmagent/internal/validate_test.go diff --git a/nmagent/internal/validate.go b/nmagent/internal/validate.go index 71ab90bdd4..de6ebfa221 100644 --- a/nmagent/internal/validate.go +++ b/nmagent/internal/validate.go @@ -2,7 +2,6 @@ package internal import ( "fmt" - "reflect" "strings" ) @@ -17,35 +16,3 @@ func (v ValidationError) Error() string { func (v ValidationError) IsEmpty() bool { return len(v.MissingFields) == 0 } - -// Validate searches for validate struct tags and performs the validations -// requested by them -func Validate(obj interface{}) error { - errs := ValidationError{} - - val := reflect.ValueOf(obj) - typ := reflect.TypeOf(obj) - - for i := 0; i < val.NumField(); i++ { - fieldVal := val.Field(i) - fieldTyp := typ.Field(i) - - op := fieldTyp.Tag.Get("validate") - switch op { - case "presence": - if fieldVal.Kind() == reflect.Slice { - if fieldVal.Len() == 0 { - errs.MissingFields = append(errs.MissingFields, fieldTyp.Name) - } - } else if fieldVal.IsZero() { - errs.MissingFields = append(errs.MissingFields, fieldTyp.Name) - } - } - } - - if errs.IsEmpty() { - return nil - } - - return errs -} diff --git a/nmagent/internal/validate_test.go b/nmagent/internal/validate_test.go deleted file mode 100644 index 884774d94b..0000000000 --- a/nmagent/internal/validate_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package internal - -import "testing" - -func TestValidate(t *testing.T) { - validateTests := []struct { - name string - sub interface{} - shouldBeValid bool - shouldPanic bool - }{ - { - "empty", - struct{}{}, - true, - false, - }, - { - "no tags", - struct { - Foo string - }{""}, - true, - false, - }, - { - "presence", - struct { - Foo string `validate:"presence"` - }{"hi"}, - true, - false, - }, - { - "presence empty", - struct { - Foo string `validate:"presence"` - }{}, - false, - false, - }, - { - "required empty slice", - struct { - Foo []string `validate:"presence"` - }{}, - false, - false, - }, - { - "not a struct", - 42, - false, - true, - }, - { - "slice", - []interface{}{}, - false, - true, - }, - { - "map", - map[string]interface{}{}, - false, - true, - }, - } - - for _, test := range validateTests { - test := test - t.Run(test.name, func(t *testing.T) { - defer func() { - if err := recover(); err != nil && !test.shouldPanic { - t.Fatal("unexpected panic received: err:", err) - } else if err == nil && test.shouldPanic { - t.Fatal("expected panic but received none") - } - }() - t.Parallel() - - err := Validate(test.sub) - if err != nil && test.shouldBeValid { - t.Fatal("unexpected error validating: err:", err) - } - - if err == nil && !test.shouldBeValid { - t.Fatal("expected subject to be invalid but wasn't") - } - }) - } -} diff --git a/nmagent/requests.go b/nmagent/requests.go index 3734ee1a52..9f02b0e4fe 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -40,16 +40,16 @@ type PutNetworkContainerRequest struct { VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet // Version is the new network container version - Version uint64 `validate:"presence" json:"version"` + Version uint64 `json:"version"` // SubnetName is the name of the delegated subnet. This is used to // authenticate the request. The list of ipv4addresses must be contained in // the subnet's prefix. - SubnetName string `validate:"presence" json:"subnetName"` + SubnetName string `json:"subnetName"` // IPv4 addresses in the customer virtual network that will be assigned to // the interface. - IPv4Addrs []string `validate:"presence" json:"ipV4Addresses"` + IPv4Addrs []string `json:"ipV4Addresses"` Policies []Policy `json:"policies"` // policies applied to the network container @@ -58,7 +58,7 @@ type PutNetworkContainerRequest struct { VlanID int `json:"vlanId"` // VirtualNetworkID is the ID of the customer's virtual network - VirtualNetworkID string `validate:"presence" json:"virtualNetworkId"` + VirtualNetworkID string `json:"virtualNetworkId"` // AuthenticationToken is the base64 security token for the subnet containing // the Network Container addresses @@ -94,7 +94,28 @@ func (p PutNetworkContainerRequest) Path() string { // Validate ensures that all of the required parameters of the request have // been filled out properly prior to submission to NMAgent func (p PutNetworkContainerRequest) Validate() error { - return internal.Validate(p) + err := internal.ValidationError{} + + if p.Version == 0 { + err.MissingFields = append(err.MissingFields, "Version") + } + + if p.SubnetName == "" { + err.MissingFields = append(err.MissingFields, "SubnetName") + } + + if len(p.IPv4Addrs) == 0 { + err.MissingFields = append(err.MissingFields, "IPv4Addrs") + } + + if p.VirtualNetworkID == "" { + err.MissingFields = append(err.MissingFields, "VirtualNetworkID") + } + + if err.IsEmpty() { + return nil + } + return err } type Policy struct { @@ -157,7 +178,16 @@ func (j JoinNetworkRequest) Method() string { // Validate ensures that the provided parameters of the request are valid func (j JoinNetworkRequest) Validate() error { - return internal.Validate(j) + err := internal.ValidationError{} + + if j.NetworkID == "" { + err.MissingFields = append(err.MissingFields, "NetworkID") + } + + if err.IsEmpty() { + return nil + } + return err } var _ Request = DeleteContainerRequest{} @@ -165,12 +195,12 @@ var _ Request = DeleteContainerRequest{} // DeleteContainerRequest represents all information necessary to request that // NMAgent delete a particular network container type DeleteContainerRequest struct { - NCID string `validate:"presence" json:"-"` // the Network Container ID + NCID string `json:"-"` // the Network Container ID // PrimaryAddress is the primary customer address of the interface in the // management VNET - PrimaryAddress string `validate:"presence" json:"-"` - AuthenticationToken string `validate:"presence" json:"-"` + PrimaryAddress string `json:"-"` + AuthenticationToken string `json:"-"` } // Path returns the path for submitting a DeleteContainerRequest with @@ -193,7 +223,24 @@ func (d DeleteContainerRequest) Method() string { // Validate ensures that the DeleteContainerRequest has the correct information // to submit the request func (d DeleteContainerRequest) Validate() error { - return internal.Validate(d) + err := internal.ValidationError{} + + if d.NCID == "" { + err.MissingFields = append(err.MissingFields, "NCID") + } + + if d.PrimaryAddress == "" { + err.MissingFields = append(err.MissingFields, "PrimaryAddress") + } + + if d.AuthenticationToken == "" { + err.MissingFields = append(err.MissingFields, "AuthenticationToken") + } + + if err.IsEmpty() { + return nil + } + return err } var _ Request = GetNetworkConfigRequest{} @@ -223,5 +270,14 @@ func (g GetNetworkConfigRequest) Method() string { // Validate ensures that the request is complete and the parameters are correct func (g GetNetworkConfigRequest) Validate() error { - return internal.Validate(g) + err := internal.ValidationError{} + + if g.VNetID == "" { + err.MissingFields = append(err.MissingFields, "VNetID") + } + + if err.IsEmpty() { + return nil + } + return err } From d4f4d628d70c369cd81fc59b523772b4a97524e7 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 12 Apr 2022 12:57:53 -0400 Subject: [PATCH 33/44] Address Linter Feedback The linter flagged many items here because it wasn't being run locally during development. This addresses all of the feedback. --- nmagent/client.go | 15 ++++----- nmagent/client_test.go | 17 +++++----- nmagent/error.go | 3 +- nmagent/internal/errors.go | 9 ++++++ nmagent/internal/retry.go | 10 ++++-- nmagent/internal/retry_example_test.go | 2 +- nmagent/internal/retry_test.go | 3 +- nmagent/internal/wireserver.go | 19 +++++++----- nmagent/internal/wireserver_test.go | 43 ++++++++++++++++---------- nmagent/requests.go | 22 +++++++------ nmagent/requests_test.go | 1 - 11 files changed, 85 insertions(+), 59 deletions(-) create mode 100644 nmagent/internal/errors.go diff --git a/nmagent/client.go b/nmagent/client.go index 77b5a45e7e..ff39ace75d 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -9,9 +9,8 @@ import ( "strconv" "time" - "github.com/pkg/errors" - "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/pkg/errors" ) // NewClient returns an initialized Client using the provided configuration @@ -30,6 +29,7 @@ func NewClient(c Config) (*Client, error) { port: c.Port, enableTLS: c.UseTLS, retrier: internal.Retrier{ + // nolint:gomnd // the base parameter is explained in the function Cooldown: internal.Exponential(1*time.Second, 2), }, } @@ -78,7 +78,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error } err = c.retrier.Do(ctx, func() error { - resp, err := c.httpClient.Do(req) + resp, err := c.httpClient.Do(req) // nolint:govet // the shadow is intentional if err != nil { return errors.Wrap(err, "executing request") } @@ -90,7 +90,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error return nil }) - return err + return err // nolint:wrapcheck // wrapping this just introduces noise } // GetNetworkConfiguration retrieves the configuration of a customer's virtual @@ -106,7 +106,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon } err = c.retrier.Do(ctx, func() error { - resp, err := c.httpClient.Do(req) + resp, err := c.httpClient.Do(req) // nolint:govet // the shadow is intentional if err != nil { return errors.Wrap(err, "executing http request to") } @@ -129,12 +129,12 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon return nil }) - return out, err + return out, err // nolint:wrapcheck // wrapping just introduces noise here } // PutNetworkContainer applies a Network Container goal state and publishes it // to PubSub -func (c *Client) PutNetworkContainer(ctx context.Context, pncr PutNetworkContainerRequest) error { +func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContainerRequest) error { requestStart := time.Now() req, err := c.buildRequest(ctx, pncr) @@ -198,6 +198,7 @@ func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, return nil, errors.Wrap(err, "retrieving request body") } + // nolint:wrapcheck // wrapping doesn't provide useful information return http.NewRequestWithContext(ctx, req.Method(), fullURL.String(), body) } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 55c977ce58..54a5bcaf59 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -8,10 +8,9 @@ import ( "net/http/httptest" "testing" + "github.com/Azure/azure-container-networking/nmagent" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" - - "github.com/Azure/azure-container-networking/nmagent" ) var _ http.RoundTripper = &TestTripper{} @@ -68,7 +67,7 @@ func TestNMAgentClientJoinNetwork(t *testing.T) { RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() - rr.Write([]byte(fmt.Sprintf(`{"httpStatusCode":"%d"}`, test.respStatus))) + _, _ = fmt.Fprintf(rr, `{"httpStatusCode":"%d"}`, test.respStatus) rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, @@ -117,7 +116,7 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { } else { rr.WriteHeader(http.StatusOK) } - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }) @@ -248,7 +247,7 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { } // we still need a fake response - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }) @@ -276,13 +275,13 @@ func TestNMAgentGetNetworkConfigRetry(t *testing.T) { func TestNMAgentPutNetworkContainer(t *testing.T) { putNCTests := []struct { name string - req nmagent.PutNetworkContainerRequest + req *nmagent.PutNetworkContainerRequest shouldCall bool shouldErr bool }{ { "happy path", - nmagent.PutNetworkContainerRequest{ + &nmagent.PutNetworkContainerRequest{ ID: "350f1e3c-4283-4f51-83a1-c44253962ef1", Version: uint64(12345), VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36", @@ -317,7 +316,7 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { client := nmagent.NewTestClient(&TestTripper{ RoundTripF: func(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) rr.WriteHeader(http.StatusOK) didCall = true return rr.Result(), nil @@ -371,7 +370,7 @@ func TestNMAgentDeleteNC(t *testing.T) { RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }) diff --git a/nmagent/error.go b/nmagent/error.go index 31d5595bb3..bbd3ac95d3 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -7,9 +7,8 @@ import ( "net/http" "time" - pkgerrors "github.com/pkg/errors" - "github.com/Azure/azure-container-networking/nmagent/internal" + pkgerrors "github.com/pkg/errors" ) // ContentError is encountered when an unexpected content type is obtained from diff --git a/nmagent/internal/errors.go b/nmagent/internal/errors.go new file mode 100644 index 0000000000..8f4a160057 --- /dev/null +++ b/nmagent/internal/errors.go @@ -0,0 +1,9 @@ +package internal + +// Error represents an internal sentinal error which can be defined as a +// constant. +type Error string + +func (e Error) Error() string { + return string(e) +} diff --git a/nmagent/internal/retry.go b/nmagent/internal/retry.go index 339f8c9de5..8f27b7333f 100644 --- a/nmagent/internal/retry.go +++ b/nmagent/internal/retry.go @@ -3,7 +3,6 @@ package internal import ( "context" "errors" - "fmt" "math" "time" @@ -14,6 +13,10 @@ const ( noDelay = 0 * time.Nanosecond ) +const ( + ErrMaxAttempts = Error("maximum attempts reached") +) + // TemporaryError is an error that can indicate whether it may be resolved with // another attempt type TemporaryError interface { @@ -35,6 +38,7 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { for { if err := ctx.Err(); err != nil { + // nolint:wrapcheck // no meaningful information can be added to this error return err } @@ -43,7 +47,7 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { // check to see if it's temporary var tempErr TemporaryError if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { - delay, err := cooldown() + delay, err := cooldown() // nolint:govet // the shadow is intentional if err != nil { return pkgerrors.Wrap(err, "sleeping during retry") } @@ -75,7 +79,7 @@ func Max(limit int, factory CooldownFactory) CooldownFactory { count := 0 return func() (time.Duration, error) { if count >= limit { - return noDelay, fmt.Errorf("maximum attempts reached (%d)", limit) + return noDelay, ErrMaxAttempts } delay, err := cooldown() diff --git a/nmagent/internal/retry_example_test.go b/nmagent/internal/retry_example_test.go index 1d3f43cc1b..c66bc194fb 100644 --- a/nmagent/internal/retry_example_test.go +++ b/nmagent/internal/retry_example_test.go @@ -63,6 +63,6 @@ func ExampleMax() { // 10ms // 10ms // 10ms - // error cooling down: maximum attempts reached (4) + // error cooling down: maximum attempts reached } } diff --git a/nmagent/internal/retry_test.go b/nmagent/internal/retry_test.go index 70fe00481f..55824de38b 100644 --- a/nmagent/internal/retry_test.go +++ b/nmagent/internal/retry_test.go @@ -34,7 +34,6 @@ func TestBackoffRetry(t *testing.T) { } return nil }) - if err != nil { t.Fatal("unexpected error: err:", err) } @@ -83,7 +82,7 @@ func TestBackoffRetryUnretriableError(t *testing.T) { } err := rt.Do(context.Background(), func() error { - return errors.New("boom") + return errors.New("boom") // nolint:goerr113 // it's just a test }) if err == nil { diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 85c6fa3273..04f0f26f0e 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "errors" - "fmt" "io" "net/http" "strconv" @@ -14,9 +13,10 @@ import ( ) const ( + // nolint:gomnd // constantizing just obscures meaning here _ int64 = 1 << (10 * iota) kilobyte - megabyte + // megabyte ) const ( @@ -26,6 +26,9 @@ const ( // the event that no Content-Length is provided. The responses are relatively // small, so the smallest page size should be sufficient DefaultBufferSize int64 = 4 * kilobyte + + // errors + ErrNoStatusCode = Error("no httpStatusCode property returned in Wireserver response") ) var _ http.RoundTripper = &WireserverTransport{} @@ -42,13 +45,13 @@ func (w WireserverResponse) StatusCode() (int, error) { return 0, pkgerrors.Wrap(err, "unmarshaling httpStatusCode from Wireserver") } - if code, err := strconv.Atoi(statusStr); err != nil { + code, err := strconv.Atoi(statusStr) + if err != nil { return code, pkgerrors.Wrap(err, "parsing http status code from wireserver") - } else { - return code, nil } + return code, nil } - return 0, fmt.Errorf("no httpStatusCode property returned in Wireserver response") + return 0, ErrNoStatusCode } // WireserverTransport is an http.RoundTripper that applies transformation @@ -100,7 +103,7 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er // execute the request to the downstream transport resp, err := w.Transport.RoundTrip(req) if err != nil { - return nil, err + return nil, pkgerrors.Wrap(err, "executing request to wireserver") } if resp.StatusCode != http.StatusOK { @@ -145,6 +148,8 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er // unmodified resp.Header.Set(HeaderContentType, http.DetectContentType(body)) resp.Body = io.NopCloser(bytes.NewReader(body)) + + // nolint:nilerr // we effectively "fix" this error because it's expected return resp, nil } diff --git a/nmagent/internal/wireserver_test.go b/nmagent/internal/wireserver_test.go index 94594ce99d..c9c4cdeb4d 100644 --- a/nmagent/internal/wireserver_test.go +++ b/nmagent/internal/wireserver_test.go @@ -62,7 +62,7 @@ func TestWireserverTransportPathTransform(t *testing.T) { got = r.URL.Path rr := httptest.NewRecorder() rr.WriteHeader(http.StatusOK) - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) return rr.Result(), nil }, }, @@ -71,15 +71,17 @@ func TestWireserverTransportPathTransform(t *testing.T) { // execute - req, err := http.NewRequest(test.method, test.sub, nil) + //nolint:noctx // just a test + req, err := http.NewRequest(test.method, test.sub, http.NoBody) if err != nil { t.Fatal("error creating new request: err:", err) } - _, err = client.Do(req) + resp, err := client.Do(req) if err != nil { t.Fatal("unexpected error submitting request: err:", err) } + defer resp.Body.Close() // assert if got != test.exp { @@ -147,7 +149,8 @@ func TestWireserverTransportStatusTransform(t *testing.T) { // execute - req, err := http.NewRequest(http.MethodGet, "/test/path", nil) + // nolint:noctx // just a test + req, err := http.NewRequest(http.MethodGet, "/test/path", http.NoBody) if err != nil { t.Fatal("error creating new request: err:", err) } @@ -188,7 +191,7 @@ func TestWireserverTransportPutPost(t *testing.T) { RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.Method rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, @@ -196,15 +199,16 @@ func TestWireserverTransportPutPost(t *testing.T) { }, } - req, err := http.NewRequest(http.MethodPut, "/test/path", nil) + req, err := http.NewRequest(http.MethodPut, "/test/path", http.NoBody) if err != nil { t.Fatal("unexpected error creating http request: err:", err) } - _, err = client.Do(req) + resp, err := client.Do(req) if err != nil { t.Fatal("error submitting request: err:", err) } + defer resp.Body.Close() exp := http.MethodPost if got != exp { @@ -223,7 +227,7 @@ func TestWireserverTransportPostBody(t *testing.T) { RoundTripF: func(req *http.Request) (*http.Response, error) { bodyIsNil = req.Body == nil rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, @@ -231,31 +235,34 @@ func TestWireserverTransportPostBody(t *testing.T) { }, } - // PUT - req, err := http.NewRequest(http.MethodPut, "/test/path", nil) + // PUT request + req, err := http.NewRequest(http.MethodPut, "/test/path", http.NoBody) if err != nil { t.Fatal("unexpected error creating http request: err:", err) } - _, err = client.Do(req) + resp, err := client.Do(req) if err != nil { t.Fatal("error submitting request: err:", err) } + defer resp.Body.Close() if bodyIsNil { t.Error("downstream request body to wireserver was nil, but not expected to be") } - // POST - req, err = http.NewRequest(http.MethodPost, "/test/path", nil) + // POST request + // nolint:noctx // just a test + req, err = http.NewRequest(http.MethodPost, "/test/path", http.NoBody) if err != nil { t.Fatal("unexpected error creating http request: err:", err) } - _, err = client.Do(req) + resp, err = client.Do(req) if err != nil { t.Fatal("error submitting request: err:", err) } + defer resp.Body.Close() if bodyIsNil { t.Error("downstream request body to wireserver was nil, but not expected to be") @@ -274,7 +281,7 @@ func TestWireserverTransportQuery(t *testing.T) { RoundTripF: func(req *http.Request) (*http.Response, error) { got = req.URL.Path rr := httptest.NewRecorder() - rr.Write([]byte(`{"httpStatusCode": "200"}`)) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) rr.WriteHeader(http.StatusOK) return rr.Result(), nil }, @@ -282,15 +289,17 @@ func TestWireserverTransportQuery(t *testing.T) { }, } - req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", nil) + // nolint:noctx // just a test + req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", http.NoBody) if err != nil { t.Fatal("unexpected error creating http request: err:", err) } - _, err = client.Do(req) + resp, err := client.Do(req) if err != nil { t.Fatal("error submitting request: err:", err) } + defer resp.Body.Close() exp := "/machine/plugins/?comp=nmagent&type=test/path/api-version/1234/foo/bar" if got != exp { diff --git a/nmagent/requests.go b/nmagent/requests.go index 9f02b0e4fe..3654f73a5d 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -9,13 +9,12 @@ import ( "strings" "unicode" - "github.com/pkg/errors" - "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/pkg/errors" ) // Request represents an abstracted HTTP request, capable of validating itself, -// producting a valid Path, Body, and its Method +// producing a valid Path, Body, and its Method. type Request interface { // Validate should ensure that the request is valid to submit Validate() error @@ -31,7 +30,7 @@ type Request interface { Method() string } -var _ Request = PutNetworkContainerRequest{} +var _ Request = &PutNetworkContainerRequest{} // PutNetworkContainerRequest is a collection of parameters necessary to create // a new network container @@ -71,7 +70,7 @@ type PutNetworkContainerRequest struct { // Body marshals the JSON fields of the request and produces an Reader intended // for use with an HTTP request -func (p PutNetworkContainerRequest) Body() (io.Reader, error) { +func (p *PutNetworkContainerRequest) Body() (io.Reader, error) { body, err := json.Marshal(p) if err != nil { return nil, errors.Wrap(err, "marshaling PutNetworkContainerRequest") @@ -81,19 +80,19 @@ func (p PutNetworkContainerRequest) Body() (io.Reader, error) { } // Method returns the HTTP method for this request type -func (p PutNetworkContainerRequest) Method() string { +func (p *PutNetworkContainerRequest) Method() string { return http.MethodPost } // Path returns the URL path necessary to submit this PutNetworkContainerRequest -func (p PutNetworkContainerRequest) Path() string { +func (p *PutNetworkContainerRequest) Path() string { const PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" return fmt.Sprintf(PutNCRequestPath, p.PrimaryAddress, p.ID, p.AuthenticationToken) } // Validate ensures that all of the required parameters of the request have // been filled out properly prior to submission to NMAgent -func (p PutNetworkContainerRequest) Validate() error { +func (p *PutNetworkContainerRequest) Validate() error { err := internal.ValidationError{} if p.Version == 0 { @@ -131,11 +130,14 @@ func (p Policy) MarshalJSON() ([]byte, error) { out.WriteString(p.Type) outStr := out.String() + // nolint:wrapcheck // wrapping this error provides no useful information return json.Marshal(outStr) } // UnmarshalJSON decodes a JSON-encoded policy string func (p *Policy) UnmarshalJSON(in []byte) error { + const expectedNumParts = 2 + var raw string err := json.Unmarshal(in, &raw) if err != nil { @@ -143,8 +145,8 @@ func (p *Policy) UnmarshalJSON(in []byte) error { } parts := strings.Split(raw, ",") - if len(parts) != 2 { - return fmt.Errorf("policies must be two comma-separated values") + if len(parts) != expectedNumParts { + return errors.New("policies must be two comma-separated values") } p.ID = strings.TrimFunc(parts[0], unicode.IsSpace) diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index 28dc119930..9d1069c47e 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/Azure/azure-container-networking/nmagent" - "github.com/google/go-cmp/cmp" ) From 2bd7f8ed38424213a308ad4eda2c0ef04c56b85e Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Tue, 12 Apr 2022 13:11:34 -0400 Subject: [PATCH 34/44] Remove the UnauthorizedGracePeriod NMAgent only defines 102 processing as a temporary status. It's up to consumers of the client to determine whether an unauthorized status means that it should be retried or not. --- nmagent/client.go | 26 ++++---------------------- nmagent/error.go | 20 +++++++++----------- nmagent/nmagent_test.go | 29 +---------------------------- 3 files changed, 14 insertions(+), 61 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index ff39ace75d..d1699b3da7 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -70,8 +70,6 @@ func WithUnauthorizedGracePeriod(grace time.Duration) Option { // JoinNetwork joins a node to a customer's virtual network func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { - requestStart := time.Now() - req, err := c.buildRequest(ctx, jnr) if err != nil { return errors.Wrap(err, "building request") @@ -85,7 +83,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(time.Since(requestStart), resp.StatusCode) + return Error{resp.StatusCode} } return nil }) @@ -96,8 +94,6 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error // GetNetworkConfiguration retrieves the configuration of a customer's virtual // network. Only subnets which have been delegated will be returned func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkConfigRequest) (VirtualNetwork, error) { - requestStart := time.Now() - var out VirtualNetwork req, err := c.buildRequest(ctx, gncr) @@ -113,7 +109,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(time.Since(requestStart), resp.StatusCode) + return Error{resp.StatusCode} } ct := resp.Header.Get(internal.HeaderContentType) @@ -135,8 +131,6 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon // PutNetworkContainer applies a Network Container goal state and publishes it // to PubSub func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContainerRequest) error { - requestStart := time.Now() - req, err := c.buildRequest(ctx, pncr) if err != nil { return errors.Wrap(err, "building request") @@ -149,7 +143,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContai defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(time.Since(requestStart), resp.StatusCode) + return Error{resp.StatusCode} } return nil } @@ -157,8 +151,6 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContai // DeleteNetworkContainer removes a Network Container, its associated IP // addresses, and network policies from an interface func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainerRequest) error { - requestStart := time.Now() - req, err := c.buildRequest(ctx, dcr) if err != nil { return errors.Wrap(err, "building request") @@ -171,7 +163,7 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainer defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(time.Since(requestStart), resp.StatusCode) + return Error{resp.StatusCode} } return nil @@ -208,13 +200,3 @@ func (c *Client) scheme() string { } return "http" } - -// error constructs a NMAgent error while providing some information configured -// at instantiation -func (c *Client) error(runtime time.Duration, code int) error { - return Error{ - Runtime: runtime, - Limit: c.unauthorizedGracePeriod, - Code: code, - } -} diff --git a/nmagent/error.go b/nmagent/error.go index bbd3ac95d3..fe8bd71382 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/Azure/azure-container-networking/nmagent/internal" pkgerrors "github.com/pkg/errors" @@ -50,9 +49,7 @@ func NewContentError(contentType string, in io.Reader, limit int64) error { // Error is a aberrent condition encountered when interacting with the NMAgent // API type Error struct { - Runtime time.Duration // the amount of time the operation has been running - Limit time.Duration // the maximum amount of time the operation can run for - Code int // the HTTP status code received + Code int // the HTTP status code received } // Error constructs a string representation of this error in accordance with @@ -79,14 +76,15 @@ func (e Error) Message() string { // Temporary reports whether the error encountered from NMAgent should be // considered temporary, and thus retriable func (e Error) Temporary() bool { - // We consider Unauthorized responses from NMAgent to be temporary for a - // certain period of time. This is to allow for situations where an - // authorization token may not yet be available - if e.Code == http.StatusUnauthorized { - return e.Runtime < e.Limit - } - // NMAgent will return a 102 (Processing) if the request is taking time to // complete. These should be attempted again. return e.Code == http.StatusProcessing } + +// Unauthorized reports whether the error was produced as a result of +// submitting the request from an interface without an OwningServiceInstanceId +// property. In some cases, this can be a transient condition that could be +// retried. +func (e Error) Unauthorized() bool { + return e.Code == http.StatusUnauthorized +} diff --git a/nmagent/nmagent_test.go b/nmagent/nmagent_test.go index 79068a3e41..464a1528c0 100644 --- a/nmagent/nmagent_test.go +++ b/nmagent/nmagent_test.go @@ -7,7 +7,6 @@ import ( "net/http" "strings" "testing" - "time" "github.com/Azure/azure-container-networking/nmagent" ) @@ -33,38 +32,12 @@ func TestErrorTemp(t *testing.T) { true, }, { - "unauthorized temporary", - nmagent.Error{ - Runtime: 30 * time.Second, - Limit: 1 * time.Minute, - Code: http.StatusUnauthorized, - }, - true, - }, - { - "unauthorized permanent", - nmagent.Error{ - Runtime: 2 * time.Minute, - Limit: 1 * time.Minute, - Code: http.StatusUnauthorized, - }, - false, - }, - { - "unauthorized zero values", + "unauthorized", nmagent.Error{ Code: http.StatusUnauthorized, }, false, }, - { - "unauthorized zero limit", - nmagent.Error{ - Runtime: 2 * time.Minute, - Code: http.StatusUnauthorized, - }, - false, - }, } for _, test := range errorTests { From 196e5c5da0be639424a426fdfdc76a9d83add5d3 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Thu, 14 Apr 2022 17:14:05 -0400 Subject: [PATCH 35/44] Add error source to NMA error One of the problems with using the WireserverTransport to modify the http status code is that it obscures the source of those errors. Should there be an issue with NMAgent or Wireserver, it will be difficult (or impossible) to figure out which is which. The error itself should tell you, and WireserverTransport knows which component is responsible. This adds a header to the HTTP response and uses that to communicate the responsible party. This is then wired into the outgoing error so that clients can take appropriate action. --- nmagent/client.go | 17 +++++++--- nmagent/error.go | 15 +++++++-- nmagent/internal/wireserver.go | 59 +++++++++++++++++++++++++++++++++- 3 files changed, 83 insertions(+), 8 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index d1699b3da7..0c02b3a9ce 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -83,7 +83,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return Error{resp.StatusCode} + return c.error(resp.StatusCode, resp.Header) } return nil }) @@ -109,7 +109,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return Error{resp.StatusCode} + return c.error(resp.StatusCode, resp.Header) } ct := resp.Header.Get(internal.HeaderContentType) @@ -143,7 +143,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContai defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return Error{resp.StatusCode} + return c.error(resp.StatusCode, resp.Header) } return nil } @@ -163,12 +163,21 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainer defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return Error{resp.StatusCode} + return c.error(resp.StatusCode, resp.Header) } return nil } +func (c *Client) error(code int, headers http.Header) error { + return Error{ + Code: code, + // this is a little strange, but the conversion below is to avoid forcing + // consumers to depend on an internal type (which they can't anyway) + Source: internal.GetErrorSource(headers).String(), + } +} + func (c *Client) hostPort() string { port := strconv.Itoa(int(c.port)) return net.JoinHostPort(c.host, port) diff --git a/nmagent/error.go b/nmagent/error.go index fe8bd71382..a0dafde00b 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -49,13 +49,22 @@ func NewContentError(contentType string, in io.Reader, limit int64) error { // Error is a aberrent condition encountered when interacting with the NMAgent // API type Error struct { - Code int // the HTTP status code received + Code int // the HTTP status code received + Source string // the component responsible for producing the error } // Error constructs a string representation of this error in accordance with // the error interface func (e Error) Error() string { - return fmt.Sprintf("nmagent: http status %d: %s", e.Code, e.Message()) + return fmt.Sprintf("nmagent: %s: http status %d: %s", e.source(), e.Code, e.Message()) +} + +func (e Error) source() string { + source := "not provided" + if e.Source != "" { + source = e.Source + } + return fmt.Sprintf("source: %s", source) } // Message interprets the HTTP Status code from NMAgent and returns the @@ -69,7 +78,7 @@ func (e Error) Message() string { case http.StatusInternalServerError: return "error occurred during nmagent's request processing" default: - return "undocumented nmagent error" + return "undocumented error" } } diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 04f0f26f0e..3251edf6fd 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -29,8 +29,59 @@ const ( // errors ErrNoStatusCode = Error("no httpStatusCode property returned in Wireserver response") + + // Headers + HeaderErrorSource = "X-Error-Source" +) + +// ErrorSource is an indicator used as a header value to indicate the source of +// non-2xx status codes +type ErrorSource int + +const ( + ErrorSourceInvalid ErrorSource = iota + ErrorSourceWireserver + ErrorSourceNMAgent ) +// String produces the string equivalent for the ErrorSource type +func (e ErrorSource) String() string { + switch e { + case ErrorSourceWireserver: + return "wireserver" + case ErrorSourceNMAgent: + return "nmagent" + case ErrorSourceInvalid: + return "" + default: + return "" + } +} + +// NewErrorSource produces an ErrorSource value from the provided string. Any +// unrecognized values will become the invalid type +func NewErrorSource(es string) ErrorSource { + switch es { + case "wireserver": + return ErrorSourceWireserver + case "nmagent": + return ErrorSourceNMAgent + default: + return ErrorSourceInvalid + } +} + +// GetErrorSource retrieves the error source from the provided HTTP headers +func GetErrorSource(head http.Header) ErrorSource { + return NewErrorSource(head.Get(HeaderErrorSource)) +} + +// SetErrorSource sets the header value necessary for communicating the error +// source +func SetErrorSource(head *http.Header, es ErrorSource) { + head.Set(HeaderErrorSource, es.String()) +} + var _ http.RoundTripper = &WireserverTransport{} // WireserverResponse represents a raw response from Wireserver @@ -107,7 +158,9 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er } if resp.StatusCode != http.StatusOK { - // something happened at Wireserver, so we should just hand this back up + // something happened at Wireserver, so set a header implicating Wireserver + // and hand the response back up + SetErrorSource(&resp.Header, ErrorSourceWireserver) return resp, nil } @@ -162,6 +215,10 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er return resp, pkgerrors.Wrap(err, "retrieving status code from wireserver response") } + // add the advisory header stating that any HTTP Status from here out is from + // NMAgent + SetErrorSource(&resp.Header, ErrorSourceNMAgent) + resp.StatusCode = realCode // re-encode the body and re-attach it to the response From 1fdf92429ec15e4e4f6a1ce8ee21c0312330c3a6 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 15 Apr 2022 17:41:59 -0400 Subject: [PATCH 36/44] Remove leftover unauthorizedGracePeriod These blocks escaped notice when the rest of the UnauthorizedGracePeriod logic was removed from the nmagent client. --- nmagent/client.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 0c02b3a9ce..f1196edbbe 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -47,10 +47,6 @@ type Client struct { enableTLS bool - // unauthorizedGracePeriod is the amount of time Unauthorized responses from - // NMAgent will be tolerated and retried - unauthorizedGracePeriod time.Duration - retrier interface { Do(context.Context, func() error) error } @@ -60,14 +56,6 @@ type Client struct { // client type Option func(*Client) -// WithUnauthorizedGracePeriod is an option to treat Unauthorized (401) -// responses from NMAgent as temporary errors for a configurable amount of time -func WithUnauthorizedGracePeriod(grace time.Duration) Option { - return func(c *Client) { - c.unauthorizedGracePeriod = grace - } -} - // JoinNetwork joins a node to a customer's virtual network func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { req, err := c.buildRequest(ctx, jnr) From f5ac638ab14b7a4e835562577cfa69351c39c7e3 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 15 Apr 2022 17:43:23 -0400 Subject: [PATCH 37/44] Remove extra validation tag This validation tag wasn't noticed when the validation struct tags were removed in a previous commit. --- nmagent/requests.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nmagent/requests.go b/nmagent/requests.go index 3654f73a5d..1628e16400 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -250,7 +250,7 @@ var _ Request = GetNetworkConfigRequest{} // GetNetworkConfigRequest is a collection of necessary information for // submitting a request for a customer's network configuration type GetNetworkConfigRequest struct { - VNetID string `validate:"presence" json:"-"` // the customer's virtual network ID + VNetID string `json:"-"` // the customer's virtual network ID } // Path produces a URL path used to submit a request From 11ec25aa898ac4a349fb6f63267eabf824c8fb2a Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 15 Apr 2022 17:44:40 -0400 Subject: [PATCH 38/44] Add the body to the nmagent.Error When errors are returned, it's useful to have the body available for inspection during debugging efforts. This captures the returned body and makes it available in the nmagent.Error. It's also printed when the error is converted to its string representation. --- nmagent/client.go | 19 +++++++++--- nmagent/client_test.go | 43 ++++++++++++++++++++++++++ nmagent/error.go | 3 +- nmagent/internal/errors.go | 54 +++++++++++++++++++++++++++++++++ nmagent/internal/errors_test.go | 38 +++++++++++++++++++++++ nmagent/internal/wireserver.go | 51 ------------------------------- 6 files changed, 151 insertions(+), 57 deletions(-) create mode 100644 nmagent/internal/errors_test.go diff --git a/nmagent/client.go b/nmagent/client.go index f1196edbbe..41714133ce 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -3,6 +3,7 @@ package nmagent import ( "context" "encoding/json" + "io" "net" "net/http" "net/url" @@ -71,7 +72,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header) + return c.error(resp.StatusCode, resp.Header, resp.Body) } return nil }) @@ -97,7 +98,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header) + return c.error(resp.StatusCode, resp.Header, resp.Body) } ct := resp.Header.Get(internal.HeaderContentType) @@ -131,7 +132,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContai defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header) + return c.error(resp.StatusCode, resp.Header, resp.Body) } return nil } @@ -151,18 +152,26 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainer defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header) + return c.error(resp.StatusCode, resp.Header, resp.Body) } return nil } -func (c *Client) error(code int, headers http.Header) error { +func (c *Client) error(code int, headers http.Header, body io.ReadCloser) error { + // read the entire body + defer body.Close() + + // nolint:errcheck // make a best effort to return whatever information we can + // returning an error here without the code and source would + // be less helpful + bodyContent, _ := io.ReadAll(body) return Error{ Code: code, // this is a little strange, but the conversion below is to avoid forcing // consumers to depend on an internal type (which they can't anyway) Source: internal.GetErrorSource(headers).String(), + Body: bodyContent, } } diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 54a5bcaf59..190b8e992c 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "github.com/Azure/azure-container-networking/nmagent" @@ -142,6 +143,48 @@ func TestNMAgentClientJoinNetworkRetry(t *testing.T) { } } +func TestWSError(t *testing.T) { + const wsError string = ` + + +InternalError +The server encountered an internal error. Please retry the request. + +
+
+` + + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + rr.WriteHeader(http.StatusInternalServerError) + _, _ = rr.WriteString(wsError) + return rr.Result(), nil + }, + }) + + req := nmagent.GetNetworkConfigRequest{ + VNetID: "4815162342", + } + _, err := client.GetNetworkConfiguration(context.TODO(), req) + + if err == nil { + t.Fatal("expected error to not be nil") + } + + var cerr nmagent.Error + ok := errors.As(err, &cerr) + if !ok { + t.Fatal("error was not an nmagent.Error") + } + + t.Log(cerr.Error()) + if !strings.Contains(cerr.Error(), "InternalError") { + t.Error("error did not contain the error content from wireserver") + } +} + func TestNMAgentGetNetworkConfig(t *testing.T) { getTests := []struct { name string diff --git a/nmagent/error.go b/nmagent/error.go index a0dafde00b..4bbb56e4d2 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -51,12 +51,13 @@ func NewContentError(contentType string, in io.Reader, limit int64) error { type Error struct { Code int // the HTTP status code received Source string // the component responsible for producing the error + Body []byte // the body of the error returned } // Error constructs a string representation of this error in accordance with // the error interface func (e Error) Error() string { - return fmt.Sprintf("nmagent: %s: http status %d: %s", e.source(), e.Code, e.Message()) + return fmt.Sprintf("nmagent: %s: http status %d: %s: body: %s", e.source(), e.Code, e.Message(), string(e.Body)) } func (e Error) source() string { diff --git a/nmagent/internal/errors.go b/nmagent/internal/errors.go index 8f4a160057..8880bab16c 100644 --- a/nmagent/internal/errors.go +++ b/nmagent/internal/errors.go @@ -1,5 +1,7 @@ package internal +import "net/http" + // Error represents an internal sentinal error which can be defined as a // constant. type Error string @@ -7,3 +9,55 @@ type Error string func (e Error) Error() string { return string(e) } + +// ErrorSource is an indicator used as a header value to indicate the source of +// non-2xx status codes +type ErrorSource int + +const ( + ErrorSourceInvalid ErrorSource = iota + ErrorSourceWireserver + ErrorSourceNMAgent +) + +// String produces the string equivalent for the ErrorSource type +func (e ErrorSource) String() string { + switch e { + case ErrorSourceWireserver: + return "wireserver" + case ErrorSourceNMAgent: + return "nmagent" + case ErrorSourceInvalid: + return "" + default: + return "" + } +} + +// NewErrorSource produces an ErrorSource value from the provided string. Any +// unrecognized values will become the invalid type +func NewErrorSource(es string) ErrorSource { + switch es { + case "wireserver": + return ErrorSourceWireserver + case "nmagent": + return ErrorSourceNMAgent + default: + return ErrorSourceInvalid + } +} + +const ( + HeaderErrorSource = "X-Error-Source" +) + +// GetErrorSource retrieves the error source from the provided HTTP headers +func GetErrorSource(head http.Header) ErrorSource { + return NewErrorSource(head.Get(HeaderErrorSource)) +} + +// SetErrorSource sets the header value necessary for communicating the error +// source +func SetErrorSource(head *http.Header, es ErrorSource) { + head.Set(HeaderErrorSource, es.String()) +} diff --git a/nmagent/internal/errors_test.go b/nmagent/internal/errors_test.go new file mode 100644 index 0000000000..43a2518cd4 --- /dev/null +++ b/nmagent/internal/errors_test.go @@ -0,0 +1,38 @@ +package internal + +import ( + "net/http" + "testing" +) + +func TestErrorSource(t *testing.T) { + esTests := []struct { + sub string + exp string + }{ + {"wireserver", "wireserver"}, + {"nmagent", "nmagent"}, + {"garbage", ""}, + {"", ""}, + } + + for _, test := range esTests { + test := test + t.Run(test.sub, func(t *testing.T) { + t.Parallel() + + // since this is intended for use with headers, this tests end-to-end + es := NewErrorSource(test.sub) + + head := http.Header{} + SetErrorSource(&head, es) + gotEs := GetErrorSource(head) + + got := gotEs.String() + + if test.exp != got { + t.Fatal("received value differs from expectation: exp:", test, "got:", got) + } + }) + } +} diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 3251edf6fd..76d2e959bc 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -29,59 +29,8 @@ const ( // errors ErrNoStatusCode = Error("no httpStatusCode property returned in Wireserver response") - - // Headers - HeaderErrorSource = "X-Error-Source" -) - -// ErrorSource is an indicator used as a header value to indicate the source of -// non-2xx status codes -type ErrorSource int - -const ( - ErrorSourceInvalid ErrorSource = iota - ErrorSourceWireserver - ErrorSourceNMAgent ) -// String produces the string equivalent for the ErrorSource type -func (e ErrorSource) String() string { - switch e { - case ErrorSourceWireserver: - return "wireserver" - case ErrorSourceNMAgent: - return "nmagent" - case ErrorSourceInvalid: - return "" - default: - return "" - } -} - -// NewErrorSource produces an ErrorSource value from the provided string. Any -// unrecognized values will become the invalid type -func NewErrorSource(es string) ErrorSource { - switch es { - case "wireserver": - return ErrorSourceWireserver - case "nmagent": - return ErrorSourceNMAgent - default: - return ErrorSourceInvalid - } -} - -// GetErrorSource retrieves the error source from the provided HTTP headers -func GetErrorSource(head http.Header) ErrorSource { - return NewErrorSource(head.Get(HeaderErrorSource)) -} - -// SetErrorSource sets the header value necessary for communicating the error -// source -func SetErrorSource(head *http.Header, es ErrorSource) { - head.Set(HeaderErrorSource, es.String()) -} - var _ http.RoundTripper = &WireserverTransport{} // WireserverResponse represents a raw response from Wireserver From 8774ca9b49dcd7958a51d3e14ec6dc96cf173cb3 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 29 Apr 2022 17:29:31 -0400 Subject: [PATCH 39/44] Remove VirtualNetworkID This was redundant, since VNetID covered the same key. It's actually unclear what would happen in this circumstance if this remained, but since it's incorrect this removes it. --- nmagent/client_test.go | 1 - nmagent/requests.go | 6 ++---- nmagent/requests_test.go | 10 ++-------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/nmagent/client_test.go b/nmagent/client_test.go index 190b8e992c..4e3e286b89 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -343,7 +343,6 @@ func TestNMAgentPutNetworkContainer(t *testing.T) { VlanID: 1234, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "0000000-0000-0000-0000-000000000000", }, true, false, diff --git a/nmagent/requests.go b/nmagent/requests.go index 1628e16400..a03d275766 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -56,8 +56,6 @@ type PutNetworkContainerRequest struct { // addresses. "0" is considered a default value by the API. VlanID int `json:"vlanId"` - // VirtualNetworkID is the ID of the customer's virtual network - VirtualNetworkID string `json:"virtualNetworkId"` // AuthenticationToken is the base64 security token for the subnet containing // the Network Container addresses @@ -107,8 +105,8 @@ func (p *PutNetworkContainerRequest) Validate() error { err.MissingFields = append(err.MissingFields, "IPv4Addrs") } - if p.VirtualNetworkID == "" { - err.MissingFields = append(err.MissingFields, "VirtualNetworkID") + if p.VNetID == "" { + err.MissingFields = append(err.MissingFields, "VNetID") } if err.IsEmpty() { diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index 9d1069c47e..d1e3a863e4 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -245,7 +245,6 @@ func TestPutNetworkContainerRequestPath(t *testing.T) { VlanID: 0, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "33333333-3333-3333-3333-333333333333", }, "/NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1", }, @@ -294,7 +293,6 @@ func TestPutNetworkContainerRequestValidate(t *testing.T) { VlanID: 0, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "33333333-3333-3333-3333-333333333333", }, true, }, @@ -315,7 +313,6 @@ func TestPutNetworkContainerRequestValidate(t *testing.T) { VlanID: 0, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "33333333-3333-3333-3333-333333333333", }, false, }, @@ -338,7 +335,6 @@ func TestPutNetworkContainerRequestValidate(t *testing.T) { VlanID: 0, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "33333333-3333-3333-3333-333333333333", }, false, }, @@ -361,15 +357,14 @@ func TestPutNetworkContainerRequestValidate(t *testing.T) { VlanID: 0, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "33333333-3333-3333-3333-333333333333", }, false, }, { - "missing version", + "missing vnet id", nmagent.PutNetworkContainerRequest{ ID: "00000000-0000-0000-0000-000000000000", - VNetID: "11111111-1111-1111-1111-111111111111", + VNetID: "", // the important part Version: uint64(12345), SubnetName: "foo", IPv4Addrs: []string{ @@ -384,7 +379,6 @@ func TestPutNetworkContainerRequestValidate(t *testing.T) { VlanID: 0, AuthenticationToken: "swordfish", PrimaryAddress: "10.0.0.1", - VirtualNetworkID: "", // the important part of the test }, false, }, From cecc8ba02eaa4e8ff4d91b23fcd63eadcec79fd2 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 29 Apr 2022 17:30:37 -0400 Subject: [PATCH 40/44] Add StatusCode to error Clients still want to be able to communicate the status code in logs, so this includes the StatusCode there as well. --- nmagent/error.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nmagent/error.go b/nmagent/error.go index 4bbb56e4d2..99b984b79c 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -91,6 +91,11 @@ func (e Error) Temporary() bool { return e.Code == http.StatusProcessing } +// StatusCode returns the HTTP status associated with this error +func (e Error) StatusCode() int { + return e.Code +} + // Unauthorized reports whether the error was produced as a result of // submitting the request from an interface without an OwningServiceInstanceId // property. In some cases, this can be a transient condition that could be From d04614a3cf5ac11ffc0d0d35b0aa51c9a76bca9a Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 29 Apr 2022 17:31:07 -0400 Subject: [PATCH 41/44] Add GreKey field to PutNetworkContainerRequest In looking at usages, this `greKey` field is undocumented but critical for certain use cases. This adds it so that it remains supported. --- nmagent/requests.go | 1 + 1 file changed, 1 insertion(+) diff --git a/nmagent/requests.go b/nmagent/requests.go index a03d275766..cf2d3d80c6 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -56,6 +56,7 @@ type PutNetworkContainerRequest struct { // addresses. "0" is considered a default value by the API. VlanID int `json:"vlanId"` + GREKey uint16 `json:"greKey"` // AuthenticationToken is the base64 security token for the subnet containing // the Network Container addresses From a9e1d24cd057383a03c2600a29d21c33d27bbf44 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Fri, 29 Apr 2022 17:42:29 -0400 Subject: [PATCH 42/44] Add periods at the end of all docstrings Docstrings should have punctuation since they're documentation. This adds punctuation to every docstring that is exported (and some that aren't). --- nmagent/client.go | 14 +++++++------- nmagent/config.go | 4 ++-- nmagent/error.go | 12 ++++++------ nmagent/internal/errors.go | 10 +++++----- nmagent/internal/retry.go | 16 ++++++++-------- nmagent/internal/wireserver.go | 10 +++++----- 6 files changed, 33 insertions(+), 33 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 41714133ce..2edc35529c 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -14,7 +14,7 @@ import ( "github.com/pkg/errors" ) -// NewClient returns an initialized Client using the provided configuration +// NewClient returns an initialized Client using the provided configuration. func NewClient(c Config) (*Client, error) { if err := c.Validate(); err != nil { return nil, errors.Wrap(err, "validating config") @@ -38,7 +38,7 @@ func NewClient(c Config) (*Client, error) { return client, nil } -// Client is an agent for exchanging information with NMAgent +// Client is an agent for exchanging information with NMAgent. type Client struct { httpClient *http.Client @@ -54,10 +54,10 @@ type Client struct { } // Option is a functional option for configuration optional behavior in the -// client +// client. type Option func(*Client) -// JoinNetwork joins a node to a customer's virtual network +// JoinNetwork joins a node to a customer's virtual network. func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { req, err := c.buildRequest(ctx, jnr) if err != nil { @@ -81,7 +81,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error } // GetNetworkConfiguration retrieves the configuration of a customer's virtual -// network. Only subnets which have been delegated will be returned +// network. Only subnets which have been delegated will be returned. func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkConfigRequest) (VirtualNetwork, error) { var out VirtualNetwork @@ -118,7 +118,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon } // PutNetworkContainer applies a Network Container goal state and publishes it -// to PubSub +// to PubSub. func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContainerRequest) error { req, err := c.buildRequest(ctx, pncr) if err != nil { @@ -138,7 +138,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContai } // DeleteNetworkContainer removes a Network Container, its associated IP -// addresses, and network policies from an interface +// addresses, and network policies from an interface. func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainerRequest) error { req, err := c.buildRequest(ctx, dcr) if err != nil { diff --git a/nmagent/config.go b/nmagent/config.go index 5514b8eb81..25b99c6fc3 100644 --- a/nmagent/config.go +++ b/nmagent/config.go @@ -2,7 +2,7 @@ package nmagent import "github.com/Azure/azure-container-networking/nmagent/internal" -// Config is a configuration for an NMAgent Client +// Config is a configuration for an NMAgent Client. type Config struct { ///////////////////// // Required Config // @@ -17,7 +17,7 @@ type Config struct { } // Validate reports whether this configuration is a valid configuration for a -// client +// client. func (c Config) Validate() error { err := internal.ValidationError{} diff --git a/nmagent/error.go b/nmagent/error.go index 99b984b79c..83e9592e39 100644 --- a/nmagent/error.go +++ b/nmagent/error.go @@ -11,7 +11,7 @@ import ( ) // ContentError is encountered when an unexpected content type is obtained from -// NMAgent +// NMAgent. type ContentError struct { Type string // the mime type of the content received Body []byte // the received body @@ -47,7 +47,7 @@ func NewContentError(contentType string, in io.Reader, limit int64) error { } // Error is a aberrent condition encountered when interacting with the NMAgent -// API +// API. type Error struct { Code int // the HTTP status code received Source string // the component responsible for producing the error @@ -55,7 +55,7 @@ type Error struct { } // Error constructs a string representation of this error in accordance with -// the error interface +// the error interface. func (e Error) Error() string { return fmt.Sprintf("nmagent: %s: http status %d: %s: body: %s", e.source(), e.Code, e.Message(), string(e.Body)) } @@ -69,7 +69,7 @@ func (e Error) source() string { } // Message interprets the HTTP Status code from NMAgent and returns the -// corresponding explanation from the documentation +// corresponding explanation from the documentation. func (e Error) Message() string { switch e.Code { case http.StatusProcessing: @@ -84,14 +84,14 @@ func (e Error) Message() string { } // Temporary reports whether the error encountered from NMAgent should be -// considered temporary, and thus retriable +// considered temporary, and thus retriable. func (e Error) Temporary() bool { // NMAgent will return a 102 (Processing) if the request is taking time to // complete. These should be attempted again. return e.Code == http.StatusProcessing } -// StatusCode returns the HTTP status associated with this error +// StatusCode returns the HTTP status associated with this error. func (e Error) StatusCode() int { return e.Code } diff --git a/nmagent/internal/errors.go b/nmagent/internal/errors.go index 8880bab16c..11a13d6b85 100644 --- a/nmagent/internal/errors.go +++ b/nmagent/internal/errors.go @@ -11,7 +11,7 @@ func (e Error) Error() string { } // ErrorSource is an indicator used as a header value to indicate the source of -// non-2xx status codes +// non-2xx status codes. type ErrorSource int const ( @@ -20,7 +20,7 @@ const ( ErrorSourceNMAgent ) -// String produces the string equivalent for the ErrorSource type +// String produces the string equivalent for the ErrorSource type. func (e ErrorSource) String() string { switch e { case ErrorSourceWireserver: @@ -35,7 +35,7 @@ func (e ErrorSource) String() string { } // NewErrorSource produces an ErrorSource value from the provided string. Any -// unrecognized values will become the invalid type +// unrecognized values will become the invalid type. func NewErrorSource(es string) ErrorSource { switch es { case "wireserver": @@ -51,13 +51,13 @@ const ( HeaderErrorSource = "X-Error-Source" ) -// GetErrorSource retrieves the error source from the provided HTTP headers +// GetErrorSource retrieves the error source from the provided HTTP headers. func GetErrorSource(head http.Header) ErrorSource { return NewErrorSource(head.Get(HeaderErrorSource)) } // SetErrorSource sets the header value necessary for communicating the error -// source +// source. func SetErrorSource(head *http.Header, es ErrorSource) { head.Set(HeaderErrorSource, es.String()) } diff --git a/nmagent/internal/retry.go b/nmagent/internal/retry.go index 8f27b7333f..9491aea0e4 100644 --- a/nmagent/internal/retry.go +++ b/nmagent/internal/retry.go @@ -18,21 +18,21 @@ const ( ) // TemporaryError is an error that can indicate whether it may be resolved with -// another attempt +// another attempt. type TemporaryError interface { error Temporary() bool } // Retrier is a construct for attempting some operation multiple times with a -// configurable backoff strategy +// configurable backoff strategy. type Retrier struct { Cooldown CooldownFactory } // Do repeatedly invokes the provided run function while the context remains // active. It waits in between invocations of the provided functions by -// delegating to the provided Cooldown function +// delegating to the provided Cooldown function. func (r Retrier) Do(ctx context.Context, run func() error) error { cooldown := r.Cooldown() @@ -44,7 +44,7 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { err := run() if err != nil { - // check to see if it's temporary + // check to see if it's temporary. var tempErr TemporaryError if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { delay, err := cooldown() // nolint:govet // the shadow is intentional @@ -72,7 +72,7 @@ type CooldownFunc func() (time.Duration, error) type CooldownFactory func() CooldownFunc // Max provides a fixed limit for the number of times a subordinate cooldown -// function can be invoked +// function can be invoked. func Max(limit int, factory CooldownFactory) CooldownFactory { return func() CooldownFunc { cooldown := factory() @@ -93,7 +93,7 @@ func Max(limit int, factory CooldownFactory) CooldownFactory { } // AsFastAsPossible is a Cooldown strategy that does not block, allowing retry -// logic to proceed as fast as possible. This is particularly useful in tests +// logic to proceed as fast as possible. This is particularly useful in tests. func AsFastAsPossible() CooldownFactory { return func() CooldownFunc { return func() (time.Duration, error) { @@ -102,7 +102,7 @@ func AsFastAsPossible() CooldownFactory { } } -// Exponential provides an exponential increase the the base interval provided +// Exponential provides an exponential increase the the base interval provided. func Exponential(interval time.Duration, base int) CooldownFactory { return func() CooldownFunc { count := 0 @@ -115,7 +115,7 @@ func Exponential(interval time.Duration, base int) CooldownFactory { } } -// Fixed produced the same delay value upon each invocation +// Fixed produced the same delay value upon each invocation. func Fixed(delay time.Duration) CooldownFactory { return func() CooldownFunc { return func() (time.Duration, error) { diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go index 76d2e959bc..de4f2f1833 100644 --- a/nmagent/internal/wireserver.go +++ b/nmagent/internal/wireserver.go @@ -33,10 +33,10 @@ const ( var _ http.RoundTripper = &WireserverTransport{} -// WireserverResponse represents a raw response from Wireserver +// WireserverResponse represents a raw response from Wireserver. type WireserverResponse map[string]json.RawMessage -// StatusCode extracts the embedded HTTP status code from the response from Wireserver +// StatusCode extracts the embedded HTTP status code from the response from Wireserver. func (w WireserverResponse) StatusCode() (int, error) { if status, ok := w["httpStatusCode"]; ok { var statusStr string @@ -55,14 +55,14 @@ func (w WireserverResponse) StatusCode() (int, error) { } // WireserverTransport is an http.RoundTripper that applies transformation -// rules to inbound requests necessary to make them compatible with Wireserver +// rules to inbound requests necessary to make them compatible with Wireserver. type WireserverTransport struct { Transport http.RoundTripper } // RoundTrip executes arbitrary HTTP requests against Wireserver while applying // the necessary transformation rules to make such requests acceptable to -// Wireserver +// Wireserver. func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, error) { // RoundTrippers are not allowed to modify the request, so we clone it here. // We need to extract the context from the request first since this is _not_ @@ -142,7 +142,7 @@ func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, er resp.ContentLength = int64(numRead) // it's unclear whether Wireserver sets Content-Type appropriately, so we - // attempt to decode it first and surface it otherwise + // attempt to decode it first and surface it otherwise. var wsResp WireserverResponse err = json.Unmarshal(body, &wsResp) if err != nil { From dc9522a7b74bf0c03e58b73a06e67eacfe129e7a Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 2 May 2022 10:18:50 -0400 Subject: [PATCH 43/44] Remove unused Option type This was leftover from a previous cleanup commit. --- nmagent/client.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 2edc35529c..9ad1ab8a1a 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -53,10 +53,6 @@ type Client struct { } } -// Option is a functional option for configuration optional behavior in the -// client. -type Option func(*Client) - // JoinNetwork joins a node to a customer's virtual network. func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { req, err := c.buildRequest(ctx, jnr) From ba69161d173511c29c7a4fd6ed5ad1355685b6c1 Mon Sep 17 00:00:00 2001 From: Tim Raymond Date: Mon, 2 May 2022 10:19:13 -0400 Subject: [PATCH 44/44] Change `error` to a function The `nmagent.(*Client).error` method wasn't actually using any part of `*Client`. Therefore it should be a function. Since we can't use `error` as a function name because it's a reserved keyword, we're throwing back to the Perl days and calling this one `die`. --- nmagent/client.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/nmagent/client.go b/nmagent/client.go index 9ad1ab8a1a..30d518c675 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -68,7 +68,7 @@ func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header, resp.Body) + return die(resp.StatusCode, resp.Header, resp.Body) } return nil }) @@ -94,7 +94,7 @@ func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkCon defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header, resp.Body) + return die(resp.StatusCode, resp.Header, resp.Body) } ct := resp.Header.Get(internal.HeaderContentType) @@ -128,7 +128,7 @@ func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContai defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header, resp.Body) + return die(resp.StatusCode, resp.Header, resp.Body) } return nil } @@ -148,16 +148,13 @@ func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainer defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return c.error(resp.StatusCode, resp.Header, resp.Body) + return die(resp.StatusCode, resp.Header, resp.Body) } return nil } -func (c *Client) error(code int, headers http.Header, body io.ReadCloser) error { - // read the entire body - defer body.Close() - +func die(code int, headers http.Header, body io.ReadCloser) error { // nolint:errcheck // make a best effort to return whatever information we can // returning an error here without the code and source would // be less helpful