diff --git a/nmagent/client.go b/nmagent/client.go new file mode 100644 index 0000000000..30d518c675 --- /dev/null +++ b/nmagent/client.go @@ -0,0 +1,201 @@ +package nmagent + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/pkg/errors" +) + +// NewClient returns an initialized Client using the provided configuration. +func NewClient(c Config) (*Client, error) { + if err := c.Validate(); err != nil { + return nil, errors.Wrap(err, "validating config") + } + + client := &Client{ + httpClient: &http.Client{ + Transport: &internal.WireserverTransport{ + Transport: http.DefaultTransport, + }, + }, + host: c.Host, + port: c.Port, + enableTLS: c.UseTLS, + retrier: internal.Retrier{ + // nolint:gomnd // the base parameter is explained in the function + Cooldown: internal.Exponential(1*time.Second, 2), + }, + } + + return client, nil +} + +// Client is an agent for exchanging information with NMAgent. +type Client struct { + httpClient *http.Client + + // config + host string + port uint16 + + enableTLS bool + + retrier interface { + Do(context.Context, func() error) error + } +} + +// JoinNetwork joins a node to a customer's virtual network. +func (c *Client) JoinNetwork(ctx context.Context, jnr JoinNetworkRequest) error { + req, err := c.buildRequest(ctx, jnr) + if err != nil { + return errors.Wrap(err, "building request") + } + + err = c.retrier.Do(ctx, func() error { + resp, err := c.httpClient.Do(req) // nolint:govet // the shadow is intentional + if err != nil { + return errors.Wrap(err, "executing request") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return die(resp.StatusCode, resp.Header, resp.Body) + } + return nil + }) + + return err // nolint:wrapcheck // wrapping this just introduces noise +} + +// GetNetworkConfiguration retrieves the configuration of a customer's virtual +// network. Only subnets which have been delegated will be returned. +func (c *Client) GetNetworkConfiguration(ctx context.Context, gncr GetNetworkConfigRequest) (VirtualNetwork, error) { + var out VirtualNetwork + + req, err := c.buildRequest(ctx, gncr) + if err != nil { + return out, errors.Wrap(err, "building request") + } + + err = c.retrier.Do(ctx, func() error { + resp, err := c.httpClient.Do(req) // nolint:govet // the shadow is intentional + if err != nil { + return errors.Wrap(err, "executing http request to") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return die(resp.StatusCode, resp.Header, resp.Body) + } + + ct := resp.Header.Get(internal.HeaderContentType) + if ct != internal.MimeJSON { + return NewContentError(ct, resp.Body, resp.ContentLength) + } + + err = json.NewDecoder(resp.Body).Decode(&out) + if err != nil { + return errors.Wrap(err, "decoding json response") + } + + return nil + }) + + return out, err // nolint:wrapcheck // wrapping just introduces noise here +} + +// PutNetworkContainer applies a Network Container goal state and publishes it +// to PubSub. +func (c *Client) PutNetworkContainer(ctx context.Context, pncr *PutNetworkContainerRequest) error { + req, err := c.buildRequest(ctx, pncr) + if err != nil { + return errors.Wrap(err, "building request") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return errors.Wrap(err, "submitting request") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return die(resp.StatusCode, resp.Header, resp.Body) + } + return nil +} + +// DeleteNetworkContainer removes a Network Container, its associated IP +// addresses, and network policies from an interface. +func (c *Client) DeleteNetworkContainer(ctx context.Context, dcr DeleteContainerRequest) error { + req, err := c.buildRequest(ctx, dcr) + if err != nil { + return errors.Wrap(err, "building request") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return errors.Wrap(err, "submitting request") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return die(resp.StatusCode, resp.Header, resp.Body) + } + + return nil +} + +func die(code int, headers http.Header, body io.ReadCloser) error { + // nolint:errcheck // make a best effort to return whatever information we can + // returning an error here without the code and source would + // be less helpful + bodyContent, _ := io.ReadAll(body) + return Error{ + Code: code, + // this is a little strange, but the conversion below is to avoid forcing + // consumers to depend on an internal type (which they can't anyway) + Source: internal.GetErrorSource(headers).String(), + Body: bodyContent, + } +} + +func (c *Client) hostPort() string { + port := strconv.Itoa(int(c.port)) + return net.JoinHostPort(c.host, port) +} + +func (c *Client) buildRequest(ctx context.Context, req Request) (*http.Request, error) { + if err := req.Validate(); err != nil { + return nil, errors.Wrap(err, "validating request") + } + + fullURL := &url.URL{ + Scheme: c.scheme(), + Host: c.hostPort(), + Path: req.Path(), + } + + body, err := req.Body() + if err != nil { + return nil, errors.Wrap(err, "retrieving request body") + } + + // nolint:wrapcheck // wrapping doesn't provide useful information + return http.NewRequestWithContext(ctx, req.Method(), fullURL.String(), body) +} + +func (c *Client) scheme() string { + if c.enableTLS { + return "https" + } + return "http" +} diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go new file mode 100644 index 0000000000..4f98959742 --- /dev/null +++ b/nmagent/client_helpers_test.go @@ -0,0 +1,24 @@ +package nmagent + +import ( + "net/http" + + "github.com/Azure/azure-container-networking/nmagent/internal" +) + +// NewTestClient is a factory function available in tests only for creating +// NMAgent clients with a mock transport +func NewTestClient(transport http.RoundTripper) *Client { + return &Client{ + httpClient: &http.Client{ + Transport: &internal.WireserverTransport{ + Transport: transport, + }, + }, + host: "localhost", + port: 12345, + retrier: internal.Retrier{ + Cooldown: internal.AsFastAsPossible(), + }, + } +} diff --git a/nmagent/client_test.go b/nmagent/client_test.go new file mode 100644 index 0000000000..4e3e286b89 --- /dev/null +++ b/nmagent/client_test.go @@ -0,0 +1,434 @@ +package nmagent_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Azure/azure-container-networking/nmagent" + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" +) + +var _ http.RoundTripper = &TestTripper{} + +// TestTripper is a RoundTripper with a customizeable RoundTrip method for +// testing purposes +type TestTripper struct { + RoundTripF func(*http.Request) (*http.Response, error) +} + +func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripF(req) +} + +func TestNMAgentClientJoinNetwork(t *testing.T) { + joinNetTests := []struct { + name string + id string + exp string + respStatus int + shouldErr bool + }{ + { + "happy path", + "00000000-0000-0000-0000-000000000000", + "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + http.StatusOK, + false, + }, + { + "empty network ID", + "", + "", + http.StatusOK, // this shouldn't be checked + true, + }, + { + "internal error", + "00000000-0000-0000-0000-000000000000", + "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + http.StatusInternalServerError, + true, + }, + } + + for _, test := range joinNetTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // create a client + var got string + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + _, _ = fmt.Fprintf(rr, `{"httpStatusCode":"%d"}`, test.respStatus) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }) + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + // attempt to join network + // TODO(timraymond): need a more realistic network ID, I think + err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{test.id}) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if got != test.exp { + t.Error("received URL differs from expectation: got", got, "exp:", test.exp) + } + }) + } +} + +func TestNMAgentClientJoinNetworkRetry(t *testing.T) { + // we want to ensure that the client will automatically follow up with + // NMAgent, so we want to track the number of requests that it makes + invocations := 0 + exp := 10 + + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if invocations < exp { + rr.WriteHeader(http.StatusProcessing) + invocations++ + } else { + rr.WriteHeader(http.StatusOK) + } + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + return rr.Result(), nil + }, + }) + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + // attempt to join network + err := client.JoinNetwork(ctx, nmagent.JoinNetworkRequest{"00000000-0000-0000-0000-000000000000"}) + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if invocations != exp { + t.Error("client did not make the expected number of API calls: got:", invocations, "exp:", exp) + } +} + +func TestWSError(t *testing.T) { + const wsError string = ` + + +InternalError +The server encountered an internal error. Please retry the request. + +
+
+` + + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + rr.WriteHeader(http.StatusInternalServerError) + _, _ = rr.WriteString(wsError) + return rr.Result(), nil + }, + }) + + req := nmagent.GetNetworkConfigRequest{ + VNetID: "4815162342", + } + _, err := client.GetNetworkConfiguration(context.TODO(), req) + + if err == nil { + t.Fatal("expected error to not be nil") + } + + var cerr nmagent.Error + ok := errors.As(err, &cerr) + if !ok { + t.Fatal("error was not an nmagent.Error") + } + + t.Log(cerr.Error()) + if !strings.Contains(cerr.Error(), "InternalError") { + t.Error("error did not contain the error content from wireserver") + } +} + +func TestNMAgentGetNetworkConfig(t *testing.T) { + getTests := []struct { + name string + vnetID string + expURL string + expVNet map[string]interface{} + shouldCall bool + shouldErr bool + }{ + { + "happy path", + "00000000-0000-0000-0000-000000000000", + "/machine/plugins/?comp=nmagent&type=NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + map[string]interface{}{ + "httpStatusCode": "200", + "cnetSpace": "10.10.1.0/24", + "defaultGateway": "10.10.0.1", + "dnsServers": []string{ + "1.1.1.1", + "1.0.0.1", + }, + "subnets": []map[string]interface{}{}, + "vnetSpace": "10.0.0.0/8", + "vnetVersion": "12345", + }, + true, + false, + }, + } + + for _, test := range getTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var got string + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + got = req.URL.Path + rr.WriteHeader(http.StatusOK) + err := json.NewEncoder(rr).Encode(&test.expVNet) + if err != nil { + return nil, errors.Wrap(err, "encoding response") + } + + return rr.Result(), nil + }, + }) + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + gotVNet, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{test.vnetID}) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if got != test.expURL && test.shouldCall { + t.Error("unexpected URL: got:", got, "exp:", test.expURL) + } + + // TODO(timraymond): this is ugly + expVnet := nmagent.VirtualNetwork{ + CNetSpace: test.expVNet["cnetSpace"].(string), + DefaultGateway: test.expVNet["defaultGateway"].(string), + DNSServers: test.expVNet["dnsServers"].([]string), + Subnets: []nmagent.Subnet{}, + VNetSpace: test.expVNet["vnetSpace"].(string), + VNetVersion: test.expVNet["vnetVersion"].(string), + } + if !cmp.Equal(gotVNet, expVnet) { + t.Error("received vnet differs from expected: diff:", cmp.Diff(gotVNet, expVnet)) + } + }) + } +} + +func TestNMAgentGetNetworkConfigRetry(t *testing.T) { + t.Parallel() + + count := 0 + exp := 10 + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + if count < exp { + rr.WriteHeader(http.StatusProcessing) + count++ + } else { + rr.WriteHeader(http.StatusOK) + } + + // we still need a fake response + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + return rr.Result(), nil + }, + }) + + // if the test provides a timeout, use it in the context + var ctx context.Context + if deadline, ok := t.Deadline(); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() + } else { + ctx = context.Background() + } + + _, err := client.GetNetworkConfiguration(ctx, nmagent.GetNetworkConfigRequest{"00000000-0000-0000-0000-000000000000"}) + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if count != exp { + t.Error("unexpected number of API calls: exp:", exp, "got:", count) + } +} + +func TestNMAgentPutNetworkContainer(t *testing.T) { + putNCTests := []struct { + name string + req *nmagent.PutNetworkContainerRequest + shouldCall bool + shouldErr bool + }{ + { + "happy path", + &nmagent.PutNetworkContainerRequest{ + ID: "350f1e3c-4283-4f51-83a1-c44253962ef1", + Version: uint64(12345), + VNetID: "be3a33e-61e3-42c7-bd23-6b949f57bd36", + SubnetName: "TestSubnet", + IPv4Addrs: []string{"10.0.0.43"}, + Policies: []nmagent.Policy{ + { + ID: "policyID1", + Type: "type1", + }, + { + ID: "policyID2", + Type: "type2", + }, + }, + VlanID: 1234, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + true, + false, + }, + } + + for _, test := range putNCTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + didCall := false + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + rr.WriteHeader(http.StatusOK) + didCall = true + return rr.Result(), nil + }, + }) + + err := client.PutNetworkContainer(context.TODO(), test.req) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if test.shouldCall && !didCall { + t.Fatal("expected call but received none") + } + + if !test.shouldCall && didCall { + t.Fatal("unexpected call. expected no call ") + } + }) + } +} + +func TestNMAgentDeleteNC(t *testing.T) { + deleteTests := []struct { + name string + req nmagent.DeleteContainerRequest + exp string + shouldErr bool + }{ + { + "happy path", + nmagent.DeleteContainerRequest{ + NCID: "00000000-0000-0000-0000-000000000000", + PrimaryAddress: "10.0.0.1", + AuthenticationToken: "swordfish", + }, + "/machine/plugins/?comp=nmagent&type=NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1/method/DELETE", + false, + }, + } + + var got string + for _, test := range deleteTests { + test := test + t.Run(test.name, func(t *testing.T) { + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + return rr.Result(), nil + }, + }) + + err := client.DeleteNetworkContainer(context.TODO(), test.req) + if err != nil && !test.shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("expected error but received none") + } + + if test.exp != got { + t.Errorf("received URL differs from expectation:\n\texp: %q:\n\tgot: %q", test.exp, got) + } + }) + } +} diff --git a/nmagent/config.go b/nmagent/config.go new file mode 100644 index 0000000000..25b99c6fc3 --- /dev/null +++ b/nmagent/config.go @@ -0,0 +1,36 @@ +package nmagent + +import "github.com/Azure/azure-container-networking/nmagent/internal" + +// Config is a configuration for an NMAgent Client. +type Config struct { + ///////////////////// + // Required Config // + ///////////////////// + Host string // the host the client will connect to + Port uint16 // the port the client will connect to + + ///////////////////// + // Optional Config // + ///////////////////// + UseTLS bool // forces all connections to use TLS +} + +// Validate reports whether this configuration is a valid configuration for a +// client. +func (c Config) Validate() error { + err := internal.ValidationError{} + + if c.Host == "" { + err.MissingFields = append(err.MissingFields, "Host") + } + + if c.Port == 0 { + err.MissingFields = append(err.MissingFields, "Port") + } + + if err.IsEmpty() { + return nil + } + return err +} diff --git a/nmagent/config_test.go b/nmagent/config_test.go new file mode 100644 index 0000000000..ed9b65f4f9 --- /dev/null +++ b/nmagent/config_test.go @@ -0,0 +1,59 @@ +package nmagent_test + +import ( + "testing" + + "github.com/Azure/azure-container-networking/nmagent" +) + +func TestConfig(t *testing.T) { + configTests := []struct { + name string + config nmagent.Config + expValid bool + }{ + { + "empty", + nmagent.Config{}, + false, + }, + { + "missing port", + nmagent.Config{ + Host: "localhost", + }, + false, + }, + { + "missing host", + nmagent.Config{ + Port: 12345, + }, + false, + }, + { + "complete", + nmagent.Config{ + Host: "localhost", + Port: 12345, + }, + true, + }, + } + + for _, test := range configTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.config.Validate() + if err != nil && test.expValid { + t.Fatal("expected config to be valid but wasnt: err:", err) + } + + if err == nil && !test.expValid { + t.Fatal("expected config to be invalid but wasn't") + } + }) + } +} diff --git a/nmagent/doc.go b/nmagent/doc.go new file mode 100644 index 0000000000..3d3f7fc597 --- /dev/null +++ b/nmagent/doc.go @@ -0,0 +1,3 @@ +// package nmagent contains types and functions necessary for interacting with +// the Network Manager Agent (NMAgent). +package nmagent diff --git a/nmagent/error.go b/nmagent/error.go new file mode 100644 index 0000000000..83e9592e39 --- /dev/null +++ b/nmagent/error.go @@ -0,0 +1,105 @@ +package nmagent + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/Azure/azure-container-networking/nmagent/internal" + pkgerrors "github.com/pkg/errors" +) + +// ContentError is encountered when an unexpected content type is obtained from +// NMAgent. +type ContentError struct { + Type string // the mime type of the content received + Body []byte // the received body +} + +func (c ContentError) Error() string { + if c.Type == internal.MimeOctetStream { + return fmt.Sprintf("unexpected content type %q: body length: %d", c.Type, len(c.Body)) + } + return fmt.Sprintf("unexpected content type %q: body: %s", c.Type, c.Body) +} + +// NewContentError creates a ContentError from a provided reader and limit. +func NewContentError(contentType string, in io.Reader, limit int64) error { + out := ContentError{ + Type: contentType, + Body: make([]byte, limit), + } + + bodyReader := io.LimitReader(in, limit) + + read, err := io.ReadFull(bodyReader, out.Body) + earlyEOF := errors.Is(err, io.ErrUnexpectedEOF) + if err != nil && !earlyEOF { + return pkgerrors.Wrap(err, "reading unexpected content body") + } + + if earlyEOF { + out.Body = out.Body[:read] + } + + return out +} + +// Error is a aberrent condition encountered when interacting with the NMAgent +// API. +type Error struct { + Code int // the HTTP status code received + Source string // the component responsible for producing the error + Body []byte // the body of the error returned +} + +// Error constructs a string representation of this error in accordance with +// the error interface. +func (e Error) Error() string { + return fmt.Sprintf("nmagent: %s: http status %d: %s: body: %s", e.source(), e.Code, e.Message(), string(e.Body)) +} + +func (e Error) source() string { + source := "not provided" + if e.Source != "" { + source = e.Source + } + return fmt.Sprintf("source: %s", source) +} + +// Message interprets the HTTP Status code from NMAgent and returns the +// corresponding explanation from the documentation. +func (e Error) Message() string { + switch e.Code { + case http.StatusProcessing: + return "the request is taking time to process. the caller should try the request again" + case http.StatusUnauthorized: + return "the request did not originate from an interface with an OwningServiceInstanceId property" + case http.StatusInternalServerError: + return "error occurred during nmagent's request processing" + default: + return "undocumented error" + } +} + +// Temporary reports whether the error encountered from NMAgent should be +// considered temporary, and thus retriable. +func (e Error) Temporary() bool { + // NMAgent will return a 102 (Processing) if the request is taking time to + // complete. These should be attempted again. + return e.Code == http.StatusProcessing +} + +// StatusCode returns the HTTP status associated with this error. +func (e Error) StatusCode() int { + return e.Code +} + +// Unauthorized reports whether the error was produced as a result of +// submitting the request from an interface without an OwningServiceInstanceId +// property. In some cases, this can be a transient condition that could be +// retried. +func (e Error) Unauthorized() bool { + return e.Code == http.StatusUnauthorized +} diff --git a/nmagent/internal/errors.go b/nmagent/internal/errors.go new file mode 100644 index 0000000000..11a13d6b85 --- /dev/null +++ b/nmagent/internal/errors.go @@ -0,0 +1,63 @@ +package internal + +import "net/http" + +// Error represents an internal sentinal error which can be defined as a +// constant. +type Error string + +func (e Error) Error() string { + return string(e) +} + +// ErrorSource is an indicator used as a header value to indicate the source of +// non-2xx status codes. +type ErrorSource int + +const ( + ErrorSourceInvalid ErrorSource = iota + ErrorSourceWireserver + ErrorSourceNMAgent +) + +// String produces the string equivalent for the ErrorSource type. +func (e ErrorSource) String() string { + switch e { + case ErrorSourceWireserver: + return "wireserver" + case ErrorSourceNMAgent: + return "nmagent" + case ErrorSourceInvalid: + return "" + default: + return "" + } +} + +// NewErrorSource produces an ErrorSource value from the provided string. Any +// unrecognized values will become the invalid type. +func NewErrorSource(es string) ErrorSource { + switch es { + case "wireserver": + return ErrorSourceWireserver + case "nmagent": + return ErrorSourceNMAgent + default: + return ErrorSourceInvalid + } +} + +const ( + HeaderErrorSource = "X-Error-Source" +) + +// GetErrorSource retrieves the error source from the provided HTTP headers. +func GetErrorSource(head http.Header) ErrorSource { + return NewErrorSource(head.Get(HeaderErrorSource)) +} + +// SetErrorSource sets the header value necessary for communicating the error +// source. +func SetErrorSource(head *http.Header, es ErrorSource) { + head.Set(HeaderErrorSource, es.String()) +} diff --git a/nmagent/internal/errors_test.go b/nmagent/internal/errors_test.go new file mode 100644 index 0000000000..43a2518cd4 --- /dev/null +++ b/nmagent/internal/errors_test.go @@ -0,0 +1,38 @@ +package internal + +import ( + "net/http" + "testing" +) + +func TestErrorSource(t *testing.T) { + esTests := []struct { + sub string + exp string + }{ + {"wireserver", "wireserver"}, + {"nmagent", "nmagent"}, + {"garbage", ""}, + {"", ""}, + } + + for _, test := range esTests { + test := test + t.Run(test.sub, func(t *testing.T) { + t.Parallel() + + // since this is intended for use with headers, this tests end-to-end + es := NewErrorSource(test.sub) + + head := http.Header{} + SetErrorSource(&head, es) + gotEs := GetErrorSource(head) + + got := gotEs.String() + + if test.exp != got { + t.Fatal("received value differs from expectation: exp:", test, "got:", got) + } + }) + } +} diff --git a/nmagent/internal/internal.go b/nmagent/internal/internal.go new file mode 100644 index 0000000000..88ec948672 --- /dev/null +++ b/nmagent/internal/internal.go @@ -0,0 +1,10 @@ +package internal + +const ( + HeaderContentType = "Content-Type" +) + +const ( + MimeJSON = "application/json" + MimeOctetStream = "application/octet-stream" +) diff --git a/nmagent/internal/retry.go b/nmagent/internal/retry.go new file mode 100644 index 0000000000..9491aea0e4 --- /dev/null +++ b/nmagent/internal/retry.go @@ -0,0 +1,125 @@ +package internal + +import ( + "context" + "errors" + "math" + "time" + + pkgerrors "github.com/pkg/errors" +) + +const ( + noDelay = 0 * time.Nanosecond +) + +const ( + ErrMaxAttempts = Error("maximum attempts reached") +) + +// TemporaryError is an error that can indicate whether it may be resolved with +// another attempt. +type TemporaryError interface { + error + Temporary() bool +} + +// Retrier is a construct for attempting some operation multiple times with a +// configurable backoff strategy. +type Retrier struct { + Cooldown CooldownFactory +} + +// Do repeatedly invokes the provided run function while the context remains +// active. It waits in between invocations of the provided functions by +// delegating to the provided Cooldown function. +func (r Retrier) Do(ctx context.Context, run func() error) error { + cooldown := r.Cooldown() + + for { + if err := ctx.Err(); err != nil { + // nolint:wrapcheck // no meaningful information can be added to this error + return err + } + + err := run() + if err != nil { + // check to see if it's temporary. + var tempErr TemporaryError + if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { + delay, err := cooldown() // nolint:govet // the shadow is intentional + if err != nil { + return pkgerrors.Wrap(err, "sleeping during retry") + } + time.Sleep(delay) + continue + } + + // since it's not temporary, it can't be retried, so... + return err + } + return nil + } +} + +// CooldownFunc is a function that will block when called. It is intended for +// use with retry logic. +type CooldownFunc func() (time.Duration, error) + +// CooldownFactory is a function that returns CooldownFuncs. It helps +// CooldownFuncs dispose of any accumulated state so that they function +// correctly upon successive uses. +type CooldownFactory func() CooldownFunc + +// Max provides a fixed limit for the number of times a subordinate cooldown +// function can be invoked. +func Max(limit int, factory CooldownFactory) CooldownFactory { + return func() CooldownFunc { + cooldown := factory() + count := 0 + return func() (time.Duration, error) { + if count >= limit { + return noDelay, ErrMaxAttempts + } + + delay, err := cooldown() + if err != nil { + return noDelay, err + } + count++ + return delay, nil + } + } +} + +// AsFastAsPossible is a Cooldown strategy that does not block, allowing retry +// logic to proceed as fast as possible. This is particularly useful in tests. +func AsFastAsPossible() CooldownFactory { + return func() CooldownFunc { + return func() (time.Duration, error) { + return noDelay, nil + } + } +} + +// Exponential provides an exponential increase the the base interval provided. +func Exponential(interval time.Duration, base int) CooldownFactory { + return func() CooldownFunc { + count := 0 + return func() (time.Duration, error) { + increment := math.Pow(float64(base), float64(count)) + delay := interval.Nanoseconds() * int64(increment) + count++ + return time.Duration(delay), nil + } + } +} + +// Fixed produced the same delay value upon each invocation. +func Fixed(delay time.Duration) CooldownFactory { + return func() CooldownFunc { + return func() (time.Duration, error) { + return delay, nil + } + } +} diff --git a/nmagent/internal/retry_example_test.go b/nmagent/internal/retry_example_test.go new file mode 100644 index 0000000000..c66bc194fb --- /dev/null +++ b/nmagent/internal/retry_example_test.go @@ -0,0 +1,68 @@ +package internal + +import ( + "fmt" + "time" +) + +func ExampleExponential() { + // this example details the common case where the powers of 2 are desired + cooldown := Exponential(1*time.Millisecond, 2)() + + for i := 0; i < 5; i++ { + got, err := cooldown() + if err != nil { + fmt.Println("received error during cooldown: err:", err) + return + } + + fmt.Println(got) + } + + // Output: + // 1ms + // 2ms + // 4ms + // 8ms + // 16ms +} + +func ExampleFixed() { + cooldown := Fixed(10 * time.Millisecond)() + + for i := 0; i < 5; i++ { + got, err := cooldown() + if err != nil { + fmt.Println("unexpected error cooling down: err", err) + return + } + fmt.Println(got) + + // Output: + // 10ms + // 10ms + // 10ms + // 10ms + // 10ms + } +} + +func ExampleMax() { + cooldown := Max(4, Fixed(10*time.Millisecond))() + + for i := 0; i < 5; i++ { + got, err := cooldown() + if err != nil { + fmt.Println("error cooling down:", err) + break + } + fmt.Println(got) + + // Output: + // 10ms + // 10ms + // 10ms + // 10ms + // error cooling down: maximum attempts reached + } +} diff --git a/nmagent/internal/retry_test.go b/nmagent/internal/retry_test.go new file mode 100644 index 0000000000..55824de38b --- /dev/null +++ b/nmagent/internal/retry_test.go @@ -0,0 +1,164 @@ +package internal + +import ( + "context" + "errors" + "testing" + "time" +) + +type TestError struct{} + +func (t TestError) Error() string { + return "oh no!" +} + +func (t TestError) Temporary() bool { + return true +} + +func TestBackoffRetry(t *testing.T) { + got := 0 + exp := 10 + + ctx := context.Background() + + rt := Retrier{ + Cooldown: AsFastAsPossible(), + } + + err := rt.Do(ctx, func() error { + if got < exp { + got++ + return TestError{} + } + return nil + }) + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if got < exp { + t.Error("unexpected number of invocations: got:", got, "exp:", exp) + } +} + +func TestBackoffRetryWithCancel(t *testing.T) { + got := 0 + exp := 5 + total := 10 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rt := Retrier{ + Cooldown: AsFastAsPossible(), + } + + err := rt.Do(ctx, func() error { + got++ + if got >= exp { + cancel() + } + + if got < total { + return TestError{} + } + return nil + }) + + if err == nil { + t.Error("expected context cancellation error, but received none") + } + + if got != exp { + t.Error("unexpected number of iterations: exp:", exp, "got:", got) + } +} + +func TestBackoffRetryUnretriableError(t *testing.T) { + rt := Retrier{ + Cooldown: AsFastAsPossible(), + } + + err := rt.Do(context.Background(), func() error { + return errors.New("boom") // nolint:goerr113 // it's just a test + }) + + if err == nil { + t.Fatal("expected an error, but none was returned") + } +} + +func TestFixed(t *testing.T) { + exp := 20 * time.Millisecond + + cooldown := Fixed(exp)() + + got, err := cooldown() + if err != nil { + t.Fatal("unexpected error invoking cooldown: err:", err) + } + + if got != exp { + t.Fatal("unexpected sleep duration: exp:", exp, "got:", got) + } +} + +func TestExp(t *testing.T) { + exp := 10 * time.Millisecond + base := 2 + + cooldown := Exponential(exp, base)() + + first, err := cooldown() + if err != nil { + t.Fatal("unexpected error invoking cooldown: err:", err) + } + + if first != exp { + t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", first) + } + + // ensure that the sleep increases + second, err := cooldown() + if err != nil { + t.Fatal("unexpected error on second invocation of cooldown: err:", err) + } + + if second < first { + t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", second) + } +} + +func TestMax(t *testing.T) { + exp := 10 + got := 0 + + // create a test sleep function + fn := func() CooldownFunc { + return func() (time.Duration, error) { + got++ + return 0 * time.Nanosecond, nil + } + } + + cooldown := Max(10, fn)() + + for i := 0; i < exp; i++ { + _, err := cooldown() + if err != nil { + t.Fatal("unexpected error from cooldown: err:", err) + } + } + + if exp != got { + t.Error("unexpected number of cooldown invocations: exp:", exp, "got:", got) + } + + // attempt one more, we expect an error + _, err := cooldown() + if err == nil { + t.Errorf("expected an error after %d invocations but received none", exp+1) + } +} diff --git a/nmagent/internal/validate.go b/nmagent/internal/validate.go new file mode 100644 index 0000000000..de6ebfa221 --- /dev/null +++ b/nmagent/internal/validate.go @@ -0,0 +1,18 @@ +package internal + +import ( + "fmt" + "strings" +) + +type ValidationError struct { + MissingFields []string +} + +func (v ValidationError) Error() string { + return fmt.Sprintf("missing fields: %s", strings.Join(v.MissingFields, ", ")) +} + +func (v ValidationError) IsEmpty() bool { + return len(v.MissingFields) == 0 +} diff --git a/nmagent/internal/wireserver.go b/nmagent/internal/wireserver.go new file mode 100644 index 0000000000..de4f2f1833 --- /dev/null +++ b/nmagent/internal/wireserver.go @@ -0,0 +1,184 @@ +package internal + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "strconv" + "strings" + + pkgerrors "github.com/pkg/errors" +) + +const ( + // nolint:gomnd // constantizing just obscures meaning here + _ int64 = 1 << (10 * iota) + kilobyte + // megabyte +) + +const ( + WirePrefix string = "/machine/plugins/?comp=nmagent&type=" + + // DefaultBufferSize is the maximum number of bytes read from Wireserver in + // the event that no Content-Length is provided. The responses are relatively + // small, so the smallest page size should be sufficient + DefaultBufferSize int64 = 4 * kilobyte + + // errors + ErrNoStatusCode = Error("no httpStatusCode property returned in Wireserver response") +) + +var _ http.RoundTripper = &WireserverTransport{} + +// WireserverResponse represents a raw response from Wireserver. +type WireserverResponse map[string]json.RawMessage + +// StatusCode extracts the embedded HTTP status code from the response from Wireserver. +func (w WireserverResponse) StatusCode() (int, error) { + if status, ok := w["httpStatusCode"]; ok { + var statusStr string + err := json.Unmarshal(status, &statusStr) + if err != nil { + return 0, pkgerrors.Wrap(err, "unmarshaling httpStatusCode from Wireserver") + } + + code, err := strconv.Atoi(statusStr) + if err != nil { + return code, pkgerrors.Wrap(err, "parsing http status code from wireserver") + } + return code, nil + } + return 0, ErrNoStatusCode +} + +// WireserverTransport is an http.RoundTripper that applies transformation +// rules to inbound requests necessary to make them compatible with Wireserver. +type WireserverTransport struct { + Transport http.RoundTripper +} + +// RoundTrip executes arbitrary HTTP requests against Wireserver while applying +// the necessary transformation rules to make such requests acceptable to +// Wireserver. +func (w *WireserverTransport) RoundTrip(inReq *http.Request) (*http.Response, error) { + // RoundTrippers are not allowed to modify the request, so we clone it here. + // We need to extract the context from the request first since this is _not_ + // cloned. The dependent Wireserver request should have the same deadline and + // cancellation properties as the inbound request though, hence the reuse. + ctx := inReq.Context() + req := inReq.Clone(ctx) + + // the original path of the request must be prefixed with wireserver's path + path := WirePrefix + if req.URL.Path != "" { + path += req.URL.Path[1:] + } + + // the query string from the request must have its constituent parts (?,=,&) + // transformed to slashes and appended to the query + if req.URL.RawQuery != "" { + query := req.URL.RawQuery + query = strings.ReplaceAll(query, "?", "/") + query = strings.ReplaceAll(query, "=", "/") + query = strings.ReplaceAll(query, "&", "/") + path += "/" + query + } + + req.URL.Path = path + + // wireserver cannot tolerate PUT requests, so it's necessary to transform + // those to POSTs + if req.Method == http.MethodPut { + req.Method = http.MethodPost + } + + // all POST requests (and by extension, PUT) must have a non-nil body + if req.Method == http.MethodPost && req.Body == nil { + req.Body = io.NopCloser(strings.NewReader("")) + } + + // execute the request to the downstream transport + resp, err := w.Transport.RoundTrip(req) + if err != nil { + return nil, pkgerrors.Wrap(err, "executing request to wireserver") + } + + if resp.StatusCode != http.StatusOK { + // something happened at Wireserver, so set a header implicating Wireserver + // and hand the response back up + SetErrorSource(&resp.Header, ErrorSourceWireserver) + return resp, nil + } + + // at this point we're definitely going to modify the body, so we want to + // make sure we close the original request's body, since we're going to + // replace it + defer func(body io.ReadCloser) { + body.Close() + }(resp.Body) + + // buffer the entire response from Wireserver + clen := resp.ContentLength + if clen < 0 { + clen = DefaultBufferSize + } + + body := make([]byte, clen) + bodyReader := io.LimitReader(resp.Body, clen) + + numRead, err := io.ReadFull(bodyReader, body) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { + return nil, pkgerrors.Wrap(err, "reading response from wireserver") + } + // it's entirely possible at this point that we read less than we allocated, + // so trim the slice back for decoding + body = body[:numRead] + + // set the content length properly in case it wasn't set. If it was, this is + // effectively a no-op + resp.ContentLength = int64(numRead) + + // it's unclear whether Wireserver sets Content-Type appropriately, so we + // attempt to decode it first and surface it otherwise. + var wsResp WireserverResponse + err = json.Unmarshal(body, &wsResp) + if err != nil { + // probably not JSON, so figure out what it is, pack it up, and surface it + // unmodified + resp.Header.Set(HeaderContentType, http.DetectContentType(body)) + resp.Body = io.NopCloser(bytes.NewReader(body)) + + // nolint:nilerr // we effectively "fix" this error because it's expected + return resp, nil + } + + // we know that it's JSON now, so communicate that upwards + resp.Header.Set(HeaderContentType, MimeJSON) + + // set the response status code with the *real* status code + realCode, err := wsResp.StatusCode() + if err != nil { + return resp, pkgerrors.Wrap(err, "retrieving status code from wireserver response") + } + + // add the advisory header stating that any HTTP Status from here out is from + // NMAgent + SetErrorSource(&resp.Header, ErrorSourceNMAgent) + + resp.StatusCode = realCode + + // re-encode the body and re-attach it to the response + delete(wsResp, "httpStatusCode") // TODO(timraymond): concern of the response + + outBody, err := json.Marshal(wsResp) + if err != nil { + return resp, pkgerrors.Wrap(err, "re-encoding json response from wireserver") + } + + resp.Body = io.NopCloser(bytes.NewReader(outBody)) + + return resp, nil +} diff --git a/nmagent/internal/wireserver_test.go b/nmagent/internal/wireserver_test.go new file mode 100644 index 0000000000..c9c4cdeb4d --- /dev/null +++ b/nmagent/internal/wireserver_test.go @@ -0,0 +1,393 @@ +package internal + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" +) + +type TestTripper struct { + // TODO(timraymond): this entire struct is duplicated + RoundTripF func(*http.Request) (*http.Response, error) +} + +func (t *TestTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripF(req) +} + +func TestWireserverTransportPathTransform(t *testing.T) { + // Wireserver introduces specific rules on how requests should be + // transformed. This test ensures we got those correct. + + pathTests := []struct { + name string + method string + sub string + exp string + }{ + { + "happy path", + http.MethodGet, + "/test/path", + "/machine/plugins/?comp=nmagent&type=test/path", + }, + { + "empty", + http.MethodGet, + "", + "/machine/plugins/?comp=nmagent&type=", + }, + { + "monopath", + http.MethodGet, + "/foo", + "/machine/plugins/?comp=nmagent&type=foo", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(r *http.Request) (*http.Response, error) { + got = r.URL.Path + rr := httptest.NewRecorder() + rr.WriteHeader(http.StatusOK) + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + return rr.Result(), nil + }, + }, + }, + } + + // execute + + //nolint:noctx // just a test + req, err := http.NewRequest(test.method, test.sub, http.NoBody) + if err != nil { + t.Fatal("error creating new request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("unexpected error submitting request: err:", err) + } + defer resp.Body.Close() + + // assert + if got != test.exp { + t.Error("received path differs from expectation: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestWireserverTransportStatusTransform(t *testing.T) { + // Wireserver only responds with 200 or 400 and embeds the actual status code + // in JSON. The Transport should correct this and return the actual status as + // an actual status + + statusTests := []struct { + name string + response map[string]interface{} + expBody map[string]interface{} + expStatus int + }{ + { + "401", + map[string]interface{}{ + "httpStatusCode": "401", + }, + map[string]interface{}{}, + http.StatusUnauthorized, + }, + { + "200 with body", + map[string]interface{}{ + "httpStatusCode": "200", + "some": "data", + }, + map[string]interface{}{ + "some": "data", + }, + http.StatusOK, + }, + } + + for _, test := range statusTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(r *http.Request) (*http.Response, error) { + rr := httptest.NewRecorder() + // mimic Wireserver handing back a 200 regardless: + rr.WriteHeader(http.StatusOK) + + err := json.NewEncoder(rr).Encode(&test.response) + if err != nil { + return nil, errors.Wrap(err, "encoding json response") + } + + return rr.Result(), nil + }, + }, + }, + } + + // execute + + // nolint:noctx // just a test + req, err := http.NewRequest(http.MethodGet, "/test/path", http.NoBody) + if err != nil { + t.Fatal("error creating new request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("unexpected error submitting request: err:", err) + } + defer resp.Body.Close() + + // assert + gotStatus := resp.StatusCode + if gotStatus != test.expStatus { + t.Errorf("status codes differ: exp: (%d) %s: got (%d): %s", test.expStatus, http.StatusText(test.expStatus), gotStatus, http.StatusText(gotStatus)) + } + + var gotBody map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&gotBody) + if err != nil { + t.Fatal("unexpected error decoding json body: err:", err) + } + + if !cmp.Equal(test.expBody, gotBody) { + t.Error("received body differs from expected: diff:", cmp.Diff(test.expBody, gotBody)) + } + }) + } +} + +func TestWireserverTransportPutPost(t *testing.T) { + // wireserver can't tolerate PUT requests, so they must be transformed to POSTs + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.Method + rr := httptest.NewRecorder() + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + req, err := http.NewRequest(http.MethodPut, "/test/path", http.NoBody) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + defer resp.Body.Close() + + exp := http.MethodPost + if got != exp { + t.Error("unexpected status: exp:", exp, "got:", got) + } +} + +func TestWireserverTransportPostBody(t *testing.T) { + // all PUT and POST requests must have an empty string body + t.Parallel() + + bodyIsNil := false + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + bodyIsNil = req.Body == nil + rr := httptest.NewRecorder() + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + // PUT request + req, err := http.NewRequest(http.MethodPut, "/test/path", http.NoBody) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + defer resp.Body.Close() + + if bodyIsNil { + t.Error("downstream request body to wireserver was nil, but not expected to be") + } + + // POST request + // nolint:noctx // just a test + req, err = http.NewRequest(http.MethodPost, "/test/path", http.NoBody) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + resp, err = client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + defer resp.Body.Close() + + if bodyIsNil { + t.Error("downstream request body to wireserver was nil, but not expected to be") + } +} + +func TestWireserverTransportQuery(t *testing.T) { + // the query string must have its constituent parts converted to slashes and + // appended to the path + t.Parallel() + + var got string + client := &http.Client{ + Transport: &WireserverTransport{ + Transport: &TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + got = req.URL.Path + rr := httptest.NewRecorder() + _, _ = rr.WriteString(`{"httpStatusCode": "200"}`) + rr.WriteHeader(http.StatusOK) + return rr.Result(), nil + }, + }, + }, + } + + // nolint:noctx // just a test + req, err := http.NewRequest(http.MethodPut, "/test/path?api-version=1234&foo=bar", http.NoBody) + if err != nil { + t.Fatal("unexpected error creating http request: err:", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal("error submitting request: err:", err) + } + defer resp.Body.Close() + + exp := "/machine/plugins/?comp=nmagent&type=test/path/api-version/1234/foo/bar" + if got != exp { + t.Error("received request differs from expectation: got:", got, "want:", exp) + } +} + +func TestWireserverResponse(t *testing.T) { + wsRespTests := []struct { + name string + resp string + exp int + shouldErr bool + }{ + { + "empty", + "{}", + 0, + true, + }, + { + "happy path", + `{ + "httpStatusCode": "401" + }`, + 401, + false, + }, + { + "missing code", + `{ + "httpStatusCode": "" + }`, + 0, + true, + }, + { + "other stuff", + `{ + "httpStatusCode": "201", + "other": "stuff" + }`, + 201, + false, + }, + { + "not a string", + `{ + "httpStatusCode": 201, + "other": "stuff" + }`, + 0, + true, + }, + { + "processing", + `{ + "httpStatusCode": "102", + "other": "stuff" + }`, + 102, + false, + }, + } + + for _, test := range wsRespTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var resp WireserverResponse + err := json.Unmarshal([]byte(test.resp), &resp) + if err != nil { + t.Fatal("unexpected unmarshaling error: err:", err) + } + + got, err := resp.StatusCode() + if err != nil && !test.shouldErr { + t.Fatal("unexpected error retrieving status code: err:", err) + } + + if err == nil && test.shouldErr { + t.Fatal("no error received when one was expected") + } + + if got != test.exp { + t.Error("received incorrect code: got:", got, "want:", test.exp) + } + }) + } +} diff --git a/nmagent/nmagent_test.go b/nmagent/nmagent_test.go new file mode 100644 index 0000000000..464a1528c0 --- /dev/null +++ b/nmagent/nmagent_test.go @@ -0,0 +1,129 @@ +package nmagent_test + +import ( + "bytes" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-container-networking/nmagent" +) + +func TestErrorTemp(t *testing.T) { + errorTests := []struct { + name string + err nmagent.Error + shouldTemp bool + }{ + { + "regular", + nmagent.Error{ + Code: http.StatusInternalServerError, + }, + false, + }, + { + "processing", + nmagent.Error{ + Code: http.StatusProcessing, + }, + true, + }, + { + "unauthorized", + nmagent.Error{ + Code: http.StatusUnauthorized, + }, + false, + }, + } + + for _, test := range errorTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if test.err.Temporary() && !test.shouldTemp { + t.Fatal("test was temporary and not expected to be") + } + + if !test.err.Temporary() && test.shouldTemp { + t.Fatal("test was not temporary but expected to be") + } + }) + } +} + +func TestContentErrorNew(t *testing.T) { + errTests := []struct { + name string + body io.Reader + limit int64 + contentType string + exp string + shouldMakeErr bool + }{ + { + "empty", + strings.NewReader(""), + 0, + "text/plain", + "unexpected content type \"text/plain\": body: ", + true, + }, + { + "happy path", + strings.NewReader("random text"), + 11, + "text/plain", + "unexpected content type \"text/plain\": body: random text", + true, + }, + { + // if the body is an octet stream, it's entirely possible that it's + // unprintable garbage. This ensures that we just print the length + "octets", + bytes.NewReader([]byte{0xde, 0xad, 0xbe, 0xef}), + 4, + "application/octet-stream", + "unexpected content type \"application/octet-stream\": body length: 4", + true, + }, + { + // even if the length is wrong, we still want to return as much data as + // we can for debugging + "wrong len", + bytes.NewReader([]byte{0xde, 0xad, 0xbe, 0xef}), + 8, + "application/octet-stream", + "unexpected content type \"application/octet-stream\": body length: 4", + true, + }, + } + + for _, test := range errTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := nmagent.NewContentError(test.contentType, test.body, test.limit) + + var e nmagent.ContentError + wasContentErr := errors.As(err, &e) + if !wasContentErr && test.shouldMakeErr { + t.Fatalf("error was not a ContentError") + } + + if wasContentErr && !test.shouldMakeErr { + t.Fatalf("received a ContentError when it was not expected") + } + + got := err.Error() + if got != test.exp { + t.Error("unexpected error message: got:", got, "exp:", test.exp) + } + }) + } +} diff --git a/nmagent/requests.go b/nmagent/requests.go new file mode 100644 index 0000000000..cf2d3d80c6 --- /dev/null +++ b/nmagent/requests.go @@ -0,0 +1,284 @@ +package nmagent + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "unicode" + + "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/pkg/errors" +) + +// Request represents an abstracted HTTP request, capable of validating itself, +// producing a valid Path, Body, and its Method. +type Request interface { + // Validate should ensure that the request is valid to submit + Validate() error + + // Path should produce a URL path, complete with any URL parameters + // interpolated + Path() string + + // Body produces the HTTP request body necessary to submit the request + Body() (io.Reader, error) + + // Method returns the HTTP Method to be used for the request. + Method() string +} + +var _ Request = &PutNetworkContainerRequest{} + +// PutNetworkContainerRequest is a collection of parameters necessary to create +// a new network container +type PutNetworkContainerRequest struct { + ID string `json:"networkContainerID"` // the id of the network container + VNetID string `json:"virtualNetworkID"` // the id of the customer's vnet + + // Version is the new network container version + Version uint64 `json:"version"` + + // SubnetName is the name of the delegated subnet. This is used to + // authenticate the request. The list of ipv4addresses must be contained in + // the subnet's prefix. + SubnetName string `json:"subnetName"` + + // IPv4 addresses in the customer virtual network that will be assigned to + // the interface. + IPv4Addrs []string `json:"ipV4Addresses"` + + Policies []Policy `json:"policies"` // policies applied to the network container + + // VlanID is used to distinguish Network Containers with duplicate customer + // addresses. "0" is considered a default value by the API. + VlanID int `json:"vlanId"` + + GREKey uint16 `json:"greKey"` + + // AuthenticationToken is the base64 security token for the subnet containing + // the Network Container addresses + AuthenticationToken string `json:"-"` + + // PrimaryAddress is the primary customer address of the interface in the + // management VNet + PrimaryAddress string `json:"-"` +} + +// Body marshals the JSON fields of the request and produces an Reader intended +// for use with an HTTP request +func (p *PutNetworkContainerRequest) Body() (io.Reader, error) { + body, err := json.Marshal(p) + if err != nil { + return nil, errors.Wrap(err, "marshaling PutNetworkContainerRequest") + } + + return bytes.NewReader(body), nil +} + +// Method returns the HTTP method for this request type +func (p *PutNetworkContainerRequest) Method() string { + return http.MethodPost +} + +// Path returns the URL path necessary to submit this PutNetworkContainerRequest +func (p *PutNetworkContainerRequest) Path() string { + const PutNCRequestPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1" + return fmt.Sprintf(PutNCRequestPath, p.PrimaryAddress, p.ID, p.AuthenticationToken) +} + +// Validate ensures that all of the required parameters of the request have +// been filled out properly prior to submission to NMAgent +func (p *PutNetworkContainerRequest) Validate() error { + err := internal.ValidationError{} + + if p.Version == 0 { + err.MissingFields = append(err.MissingFields, "Version") + } + + if p.SubnetName == "" { + err.MissingFields = append(err.MissingFields, "SubnetName") + } + + if len(p.IPv4Addrs) == 0 { + err.MissingFields = append(err.MissingFields, "IPv4Addrs") + } + + if p.VNetID == "" { + err.MissingFields = append(err.MissingFields, "VNetID") + } + + if err.IsEmpty() { + return nil + } + return err +} + +type Policy struct { + ID string + Type string +} + +// MarshalJson encodes policies as a JSON string, separated by a comma. This +// specific format is requested by the NMAgent documentation +func (p Policy) MarshalJSON() ([]byte, error) { + out := bytes.NewBufferString(p.ID) + out.WriteString(", ") + out.WriteString(p.Type) + + outStr := out.String() + // nolint:wrapcheck // wrapping this error provides no useful information + return json.Marshal(outStr) +} + +// UnmarshalJSON decodes a JSON-encoded policy string +func (p *Policy) UnmarshalJSON(in []byte) error { + const expectedNumParts = 2 + + var raw string + err := json.Unmarshal(in, &raw) + if err != nil { + return errors.Wrap(err, "decoding policy") + } + + parts := strings.Split(raw, ",") + if len(parts) != expectedNumParts { + return errors.New("policies must be two comma-separated values") + } + + p.ID = strings.TrimFunc(parts[0], unicode.IsSpace) + p.Type = strings.TrimFunc(parts[1], unicode.IsSpace) + + return nil +} + +var _ Request = JoinNetworkRequest{} + +type JoinNetworkRequest struct { + NetworkID string `validate:"presence" json:"-"` // the customer's VNet ID +} + +// Path constructs a URL path for invoking a JoinNetworkRequest using the +// provided parameters +func (j JoinNetworkRequest) Path() string { + const JoinNetworkPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + return fmt.Sprintf(JoinNetworkPath, j.NetworkID) +} + +// Body returns nothing, because JoinNetworkRequest has no request body +func (j JoinNetworkRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method returns the HTTP request method to submit a JoinNetworkRequest +func (j JoinNetworkRequest) Method() string { + return http.MethodPost +} + +// Validate ensures that the provided parameters of the request are valid +func (j JoinNetworkRequest) Validate() error { + err := internal.ValidationError{} + + if j.NetworkID == "" { + err.MissingFields = append(err.MissingFields, "NetworkID") + } + + if err.IsEmpty() { + return nil + } + return err +} + +var _ Request = DeleteContainerRequest{} + +// DeleteContainerRequest represents all information necessary to request that +// NMAgent delete a particular network container +type DeleteContainerRequest struct { + NCID string `json:"-"` // the Network Container ID + + // PrimaryAddress is the primary customer address of the interface in the + // management VNET + PrimaryAddress string `json:"-"` + AuthenticationToken string `json:"-"` +} + +// Path returns the path for submitting a DeleteContainerRequest with +// parameters interpolated correctly +func (d DeleteContainerRequest) Path() string { + const DeleteNCPath string = "/NetworkManagement/interfaces/%s/networkContainers/%s/authenticationToken/%s/api-version/1/method/DELETE" + return fmt.Sprintf(DeleteNCPath, d.PrimaryAddress, d.NCID, d.AuthenticationToken) +} + +// Body returns nothing, because DeleteContainerRequests have no HTTP body +func (d DeleteContainerRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method returns the HTTP method required to submit a DeleteContainerRequest +func (d DeleteContainerRequest) Method() string { + return http.MethodPost +} + +// Validate ensures that the DeleteContainerRequest has the correct information +// to submit the request +func (d DeleteContainerRequest) Validate() error { + err := internal.ValidationError{} + + if d.NCID == "" { + err.MissingFields = append(err.MissingFields, "NCID") + } + + if d.PrimaryAddress == "" { + err.MissingFields = append(err.MissingFields, "PrimaryAddress") + } + + if d.AuthenticationToken == "" { + err.MissingFields = append(err.MissingFields, "AuthenticationToken") + } + + if err.IsEmpty() { + return nil + } + return err +} + +var _ Request = GetNetworkConfigRequest{} + +// GetNetworkConfigRequest is a collection of necessary information for +// submitting a request for a customer's network configuration +type GetNetworkConfigRequest struct { + VNetID string `json:"-"` // the customer's virtual network ID +} + +// Path produces a URL path used to submit a request +func (g GetNetworkConfigRequest) Path() string { + const GetNetworkConfigPath string = "/NetworkManagement/joinedVirtualNetworks/%s/api-version/1" + return fmt.Sprintf(GetNetworkConfigPath, g.VNetID) +} + +// Body returns nothing because GetNetworkConfigRequest has no HTTP request +// body +func (g GetNetworkConfigRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method returns the HTTP method required to submit a GetNetworkConfigRequest +func (g GetNetworkConfigRequest) Method() string { + return http.MethodGet +} + +// Validate ensures that the request is complete and the parameters are correct +func (g GetNetworkConfigRequest) Validate() error { + err := internal.ValidationError{} + + if g.VNetID == "" { + err.MissingFields = append(err.MissingFields, "VNetID") + } + + if err.IsEmpty() { + return nil + } + return err +} diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go new file mode 100644 index 0000000000..d1e3a863e4 --- /dev/null +++ b/nmagent/requests_test.go @@ -0,0 +1,402 @@ +package nmagent_test + +import ( + "encoding/json" + "testing" + + "github.com/Azure/azure-container-networking/nmagent" + "github.com/google/go-cmp/cmp" +) + +func TestPolicyMarshal(t *testing.T) { + policyTests := []struct { + name string + policy nmagent.Policy + exp string + }{ + { + "basic", + nmagent.Policy{ + ID: "policyID1", + Type: "type1", + }, + "\"policyID1, type1\"", + }, + } + + for _, test := range policyTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + got, err := json.Marshal(test.policy) + if err != nil { + t.Fatal("unexpected err marshaling policy: err", err) + } + + if string(got) != test.exp { + t.Errorf("marshaled policy does not match expectation: got: %q: exp: %q", string(got), test.exp) + } + + var enc nmagent.Policy + err = json.Unmarshal(got, &enc) + if err != nil { + t.Fatal("unexpected error unmarshaling: err:", err) + } + + if !cmp.Equal(enc, test.policy) { + t.Error("re-encoded policy differs from expectation: diff:", cmp.Diff(enc, test.policy)) + } + }) + } +} + +func TestDeleteContainerRequestValidation(t *testing.T) { + dcrTests := []struct { + name string + req nmagent.DeleteContainerRequest + shouldBeValid bool + }{ + { + "empty", + nmagent.DeleteContainerRequest{}, + false, + }, + { + "missing ncid", + nmagent.DeleteContainerRequest{ + PrimaryAddress: "10.0.0.1", + AuthenticationToken: "swordfish", + }, + false, + }, + { + "missing primary address", + nmagent.DeleteContainerRequest{ + NCID: "00000000-0000-0000-0000-000000000000", + AuthenticationToken: "swordfish", + }, + false, + }, + { + "missing auth token", + nmagent.DeleteContainerRequest{ + NCID: "00000000-0000-0000-0000-000000000000", + PrimaryAddress: "10.0.0.1", + }, + false, + }, + } + + for _, test := range dcrTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("unexpected validation errors: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected request to be invalid but wasn't") + } + }) + } +} + +func TestJoinNetworkRequestPath(t *testing.T) { + jnr := nmagent.JoinNetworkRequest{ + NetworkID: "00000000-0000-0000-0000-000000000000", + } + + exp := "/NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1" + if jnr.Path() != exp { + t.Error("unexpected path: exp:", exp, "got:", jnr.Path()) + } +} + +func TestJoinNetworkRequestValidate(t *testing.T) { + validateRequest := []struct { + name string + req nmagent.JoinNetworkRequest + shouldBeValid bool + }{ + { + "invalid", + nmagent.JoinNetworkRequest{ + NetworkID: "", + }, + false, + }, + { + "valid", + nmagent.JoinNetworkRequest{ + NetworkID: "00000000-0000-0000-0000-000000000000", + }, + true, + }, + } + + for _, test := range validateRequest { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("unexpected error validating: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected request to be invalid but wasn't") + } + }) + } +} + +func TestGetNetworkConfigRequestPath(t *testing.T) { + pathTests := []struct { + name string + req nmagent.GetNetworkConfigRequest + exp string + }{ + { + "happy path", + nmagent.GetNetworkConfigRequest{ + VNetID: "00000000-0000-0000-0000-000000000000", + }, + "/NetworkManagement/joinedVirtualNetworks/00000000-0000-0000-0000-000000000000/api-version/1", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if got := test.req.Path(); got != test.exp { + t.Error("unexpected path: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestGetNetworkConfigRequestValidate(t *testing.T) { + validateTests := []struct { + name string + req nmagent.GetNetworkConfigRequest + shouldBeValid bool + }{ + { + "happy path", + nmagent.GetNetworkConfigRequest{ + VNetID: "00000000-0000-0000-0000-000000000000", + }, + true, + }, + { + "empty", + nmagent.GetNetworkConfigRequest{}, + false, + }, + } + + for _, test := range validateTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("expected request to be valid but wasn't: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected error to be invalid but wasn't") + } + }) + } +} + +func TestPutNetworkContainerRequestPath(t *testing.T) { + pathTests := []struct { + name string + req nmagent.PutNetworkContainerRequest + exp string + }{ + { + "happy path", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + "10.0.0.3", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + "/NetworkManagement/interfaces/10.0.0.1/networkContainers/00000000-0000-0000-0000-000000000000/authenticationToken/swordfish/api-version/1", + }, + } + + for _, test := range pathTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if got := test.req.Path(); got != test.exp { + t.Error("path differs from expectation: exp:", test.exp, "got:", got) + } + }) + } +} + +func TestPutNetworkContainerRequestValidate(t *testing.T) { + validationTests := []struct { + name string + req nmagent.PutNetworkContainerRequest + shouldBeValid bool + }{ + { + "empty", + nmagent.PutNetworkContainerRequest{}, + false, + }, + { + "happy", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + "10.0.0.3", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + true, + }, + { + "missing IPv4Addrs", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{}, // the important part + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + false, + }, + { + "missing subnet name", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(12345), + SubnetName: "", // the important part of the test + IPv4Addrs: []string{ + "10.0.0.2", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + false, + }, + { + "missing version", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "11111111-1111-1111-1111-111111111111", + Version: uint64(0), // the important part of the test + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + false, + }, + { + "missing vnet id", + nmagent.PutNetworkContainerRequest{ + ID: "00000000-0000-0000-0000-000000000000", + VNetID: "", // the important part + Version: uint64(12345), + SubnetName: "foo", + IPv4Addrs: []string{ + "10.0.0.2", + }, + Policies: []nmagent.Policy{ + { + ID: "Foo", + Type: "Bar", + }, + }, + VlanID: 0, + AuthenticationToken: "swordfish", + PrimaryAddress: "10.0.0.1", + }, + false, + }, + } + + for _, test := range validationTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := test.req.Validate() + if err != nil && test.shouldBeValid { + t.Fatal("unexpected error validating: err:", err) + } + + if err == nil && !test.shouldBeValid { + t.Fatal("expected validation error but received none") + } + }) + } +} diff --git a/nmagent/responses.go b/nmagent/responses.go new file mode 100644 index 0000000000..48705eff17 --- /dev/null +++ b/nmagent/responses.go @@ -0,0 +1,21 @@ +package nmagent + +type VirtualNetwork struct { + CNetSpace string `json:"cnetSpace"` + DefaultGateway string `json:"defaultGateway"` + DNSServers []string `json:"dnsServers"` + Subnets []Subnet `json:"subnets"` + VNetSpace string `json:"vnetSpace"` + VNetVersion string `json:"vnetVersion"` +} + +type Subnet struct { + AddressPrefix string `json:"addressPrefix"` + SubnetName string `json:"subnetName"` + Tags []Tag `json:"tags"` +} + +type Tag struct { + Name string `json:"name"` + Type string `json:"type"` // the type of the tag (e.g. "System" or "Custom") +}