From 6dc8bd021c9d97d2c8a5e15115e73744a39fd47a Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Thu, 19 Jan 2023 10:26:49 -0700 Subject: [PATCH] chore: consolidate all CLI configuration (#233) This is a port of https://github.com/GoogleCloudPlatform/cloud-sql-proxy/pull/1563. --- cmd/root.go | 93 +++++--------- cmd/root_test.go | 212 ++++++++++++++------------------ internal/proxy/internal_test.go | 79 ++++++++++++ internal/proxy/other_test.go | 11 -- internal/proxy/proxy.go | 70 +++++++++-- 5 files changed, 261 insertions(+), 204 deletions(-) create mode 100644 internal/proxy/internal_test.go diff --git a/cmd/root.go b/cmd/root.go index 1a9cce0e..a576e8ef 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -82,30 +82,10 @@ func Execute() { // Command represents an invocation of the AlloyDB Auth Proxy. type Command struct { *cobra.Command - conf *proxy.Config - logger alloydb.Logger - dialer alloydb.Dialer - - cleanup func() error - disableTraces bool - telemetryTracingSampleRate int - disableMetrics bool - telemetryProject string - telemetryPrefix string - prometheus bool - prometheusNamespace string - healthCheck bool - httpAddress string - httpPort string - quiet bool - otherUserAgents string - - // impersonationChain is a comma separated list of one or more service - // accounts. The first entry in the chain is the impersonation target. Any - // additional service accounts after the target are delegates. The - // roles/iam.serviceAccountTokenCreator must be configured for each account - // that will be impersonated. - impersonationChain string + conf *proxy.Config + logger alloydb.Logger + dialer alloydb.Dialer + cleanup func() error } // Option is a function that configures a Command. @@ -337,7 +317,7 @@ func NewCommand(opts ...Option) *Command { if c.conf.StructuredLogs { c.logger, c.cleanup = log.NewStructuredLogger() } - if c.quiet { + if c.conf.Quiet { c.logger = log.NewStdLogger(io.Discard, os.Stderr) } err := parseConfig(c, c.conf, args) @@ -356,7 +336,7 @@ func NewCommand(opts ...Option) *Command { pflags := cmd.PersistentFlags() // Global-only flags - pflags.StringVar(&c.otherUserAgents, "user-agent", "", + pflags.StringVar(&c.conf.OtherUserAgents, "user-agent", "", "Space separated list of additional user agents, e.g. cloud-sql-proxy-operator/0.0.1") pflags.StringVarP(&c.conf.Token, "token", "t", "", "Bearer token used for authorization.") @@ -384,30 +364,30 @@ the maximum time has passed. Defaults to 0s.`) pflags.StringVar(&c.conf.FUSETempDir, "fuse-tmp-dir", filepath.Join(os.TempDir(), "alloydb-tmp"), "Temp dir for Unix sockets created with FUSE") - pflags.StringVar(&c.impersonationChain, "impersonate-service-account", "", + pflags.StringVar(&c.conf.ImpersonationChain, "impersonate-service-account", "", `Comma separated list of service accounts to impersonate. Last value +is the target account.`) - cmd.PersistentFlags().BoolVar(&c.quiet, "quiet", false, "Log error messages only") + cmd.PersistentFlags().BoolVar(&c.conf.Quiet, "quiet", false, "Log error messages only") - pflags.StringVar(&c.telemetryProject, "telemetry-project", "", + pflags.StringVar(&c.conf.TelemetryProject, "telemetry-project", "", "Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.") - pflags.BoolVar(&c.disableTraces, "disable-traces", false, + pflags.BoolVar(&c.conf.DisableTraces, "disable-traces", false, "Disable Cloud Trace integration (used with telemetry-project)") - pflags.IntVar(&c.telemetryTracingSampleRate, "telemetry-sample-rate", 10_000, + pflags.IntVar(&c.conf.TelemetryTracingSampleRate, "telemetry-sample-rate", 10_000, "Configure the denominator of the probabilistic sample rate of traces sent to Cloud Trace\n(e.g., 10,000 traces 1/10,000 calls).") - pflags.BoolVar(&c.disableMetrics, "disable-metrics", false, + pflags.BoolVar(&c.conf.DisableMetrics, "disable-metrics", false, "Disable Cloud Monitoring integration (used with telemetry-project)") - pflags.StringVar(&c.telemetryPrefix, "telemetry-prefix", "", + pflags.StringVar(&c.conf.TelemetryPrefix, "telemetry-prefix", "", "Prefix to use for Cloud Monitoring metrics.") - pflags.BoolVar(&c.prometheus, "prometheus", false, + pflags.BoolVar(&c.conf.Prometheus, "prometheus", false, "Enable Prometheus HTTP endpoint /metrics") - pflags.StringVar(&c.prometheusNamespace, "prometheus-namespace", "", + pflags.StringVar(&c.conf.PrometheusNamespace, "prometheus-namespace", "", "Use the provided Prometheus namespace for metrics") - pflags.StringVar(&c.httpAddress, "http-address", "localhost", + pflags.StringVar(&c.conf.HTTPAddress, "http-address", "localhost", "Address for Prometheus and health check server") - pflags.StringVar(&c.httpPort, "http-port", "9090", + pflags.StringVar(&c.conf.HTTPPort, "http-port", "9090", "Port for the Prometheus server to use") - pflags.BoolVar(&c.healthCheck, "health-check", false, + pflags.BoolVar(&c.conf.HealthCheck, "health-check", false, `Enables HTTP endpoints /startup, /liveness, and /readiness that report on the proxy's health. Endpoints are available on localhost only. Uses the port specified by the http-port flag.`) @@ -521,23 +501,10 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { } if userHasSet("user-agent") { - defaultUserAgent += " " + cmd.otherUserAgents + defaultUserAgent += " " + cmd.conf.OtherUserAgents conf.UserAgent = defaultUserAgent } - if cmd.impersonationChain != "" { - accts := strings.Split(cmd.impersonationChain, ",") - conf.ImpersonateTarget = accts[0] - // Assign delegates if the chain is more than one account. Delegation - // goes from last back towards target, e.g., With sa1,sa2,sa3, sa3 - // delegates to sa2, which impersonates the target sa1. - if l := len(accts); l > 1 { - for i := l - 1; i > 0; i-- { - conf.ImpersonateDelegates = append(conf.ImpersonateDelegates, accts[i]) - } - } - } - var ics []proxy.InstanceConnConfig for _, a := range args { // split into instance uri and query parameters @@ -615,12 +582,12 @@ func runSignalWrapper(cmd *Command) error { // Configure collectors before the proxy has started to ensure we are // collecting metrics before *ANY* AlloyDB Admin API calls are made. - enableMetrics := !cmd.disableMetrics - enableTraces := !cmd.disableTraces - if cmd.telemetryProject != "" && (enableMetrics || enableTraces) { + enableMetrics := !cmd.conf.DisableMetrics + enableTraces := !cmd.conf.DisableTraces + if cmd.conf.TelemetryProject != "" && (enableMetrics || enableTraces) { sd, err := stackdriver.NewExporter(stackdriver.Options{ - ProjectID: cmd.telemetryProject, - MetricPrefix: cmd.telemetryPrefix, + ProjectID: cmd.conf.TelemetryProject, + MetricPrefix: cmd.conf.TelemetryPrefix, }) if err != nil { return err @@ -632,7 +599,7 @@ func runSignalWrapper(cmd *Command) error { } } if enableTraces { - s := trace.ProbabilitySampler(1 / float64(cmd.telemetryTracingSampleRate)) + s := trace.ProbabilitySampler(1 / float64(cmd.conf.TelemetryTracingSampleRate)) trace.ApplyConfig(trace.Config{DefaultSampler: s}) trace.RegisterExporter(sd) } @@ -646,10 +613,10 @@ func runSignalWrapper(cmd *Command) error { needsHTTPServer bool mux = http.NewServeMux() ) - if cmd.prometheus { + if cmd.conf.Prometheus { needsHTTPServer = true e, err := prometheus.NewExporter(prometheus.Options{ - Namespace: cmd.prometheusNamespace, + Namespace: cmd.conf.PrometheusNamespace, }) if err != nil { return err @@ -704,10 +671,10 @@ func runSignalWrapper(cmd *Command) error { }() notify := func() {} - if cmd.healthCheck { + if cmd.conf.HealthCheck { needsHTTPServer = true cmd.logger.Infof("Starting health check server at %s", - net.JoinHostPort(cmd.httpAddress, cmd.httpPort)) + net.JoinHostPort(cmd.conf.HTTPAddress, cmd.conf.HTTPPort)) hc := healthcheck.NewCheck(p, cmd.logger) mux.HandleFunc("/startup", hc.HandleStartup) mux.HandleFunc("/readiness", hc.HandleReadiness) @@ -718,7 +685,7 @@ func runSignalWrapper(cmd *Command) error { // Start the HTTP server if anything requiring HTTP is specified. if needsHTTPServer { server := &http.Server{ - Addr: net.JoinHostPort(cmd.httpAddress, cmd.httpPort), + Addr: net.JoinHostPort(cmd.conf.HTTPAddress, cmd.conf.HTTPPort), Handler: mux, } // Start the HTTP server. diff --git a/cmd/root_test.go b/cmd/root_test.go index e90ad692..19b6f58c 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -72,6 +72,15 @@ func withDefaults(c *proxy.Config) *proxy.Config { if c.FUSETempDir == "" { c.FUSETempDir = filepath.Join(os.TempDir(), "alloydb-tmp") } + if c.HTTPAddress == "" { + c.HTTPAddress = "localhost" + } + if c.HTTPPort == "" { + c.HTTPPort = "9090" + } + if c.TelemetryTracingSampleRate == 0 { + c.TelemetryTracingSampleRate = 10_000 + } if c.APIEndpointURL == "" { c.APIEndpointURL = "https://alloydb.googleapis.com/v1beta" } @@ -281,14 +290,10 @@ func TestNewCommandArguments(t *testing.T) { { desc: "using the impersonate service account flag", args: []string{"--impersonate-service-account", - "sv1@developer.gserviceaccount.com,sv2@developer.gserviceaccount.com,sv3@developer.gserviceaccount.com", + "sv1@developer.gserviceaccount.com", "projects/proj/locations/region/clusters/clust/instances/inst"}, want: withDefaults(&proxy.Config{ - ImpersonateTarget: "sv1@developer.gserviceaccount.com", - ImpersonateDelegates: []string{ - "sv3@developer.gserviceaccount.com", - "sv2@developer.gserviceaccount.com", - }, + ImpersonationChain: "sv1@developer.gserviceaccount.com", }), }, } @@ -307,113 +312,6 @@ func TestNewCommandArguments(t *testing.T) { } } -func TestNewCommandWithEnvironmentConfigPrivateFields(t *testing.T) { - tcs := []struct { - desc string - envName string - envValue string - isValid func(cmd *Command) bool - }{ - { - desc: "using the disable traces envvar", - envName: "ALLOYDB_PROXY_DISABLE_TRACES", - envValue: "true", - isValid: func(cmd *Command) bool { - return cmd.disableTraces == true - }, - }, - { - desc: "using the telemetry sample rate envvar", - envName: "ALLOYDB_PROXY_TELEMETRY_SAMPLE_RATE", - envValue: "500", - isValid: func(cmd *Command) bool { - return cmd.telemetryTracingSampleRate == 500 - }, - }, - { - desc: "using the disable metrics envvar", - envName: "ALLOYDB_PROXY_DISABLE_METRICS", - envValue: "true", - isValid: func(cmd *Command) bool { - return cmd.disableMetrics == true - }, - }, - { - desc: "using the telemetry project envvar", - envName: "ALLOYDB_PROXY_TELEMETRY_PROJECT", - envValue: "mycoolproject", - isValid: func(cmd *Command) bool { - return cmd.telemetryProject == "mycoolproject" - }, - }, - { - desc: "using the telemetry prefix envvar", - envName: "ALLOYDB_PROXY_TELEMETRY_PREFIX", - envValue: "myprefix", - isValid: func(cmd *Command) bool { - return cmd.telemetryPrefix == "myprefix" - }, - }, - { - desc: "using the prometheus envvar", - envName: "ALLOYDB_PROXY_PROMETHEUS", - envValue: "true", - isValid: func(cmd *Command) bool { - return cmd.prometheus == true - }, - }, - { - desc: "using the prometheus namespace envvar", - envName: "ALLOYDB_PROXY_PROMETHEUS_NAMESPACE", - envValue: "myns", - isValid: func(cmd *Command) bool { - return cmd.prometheusNamespace == "myns" - }, - }, - { - desc: "using the health check envvar", - envName: "ALLOYDB_PROXY_HEALTH_CHECK", - envValue: "true", - isValid: func(cmd *Command) bool { - return cmd.healthCheck == true - }, - }, - { - desc: "using the http address envvar", - envName: "ALLOYDB_PROXY_HTTP_ADDRESS", - envValue: "0.0.0.0", - isValid: func(cmd *Command) bool { - return cmd.httpAddress == "0.0.0.0" - }, - }, - { - desc: "using the http port envvar", - envName: "ALLOYDB_PROXY_HTTP_PORT", - envValue: "5555", - isValid: func(cmd *Command) bool { - return cmd.httpPort == "5555" - }, - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - os.Setenv(tc.envName, tc.envValue) - defer os.Unsetenv(tc.envName) - - c, err := invokeProxyCommand([]string{ - "projects/proj/locations/region/clusters/clust/instances/inst", - }) - if err != nil { - t.Fatalf("want error = nil, got = %v", err) - } - - if !tc.isValid(c) { - t.Fatal("want valid, got invalid") - } - }) - } -} - func TestNewCommandWithEnvironmentConfig(t *testing.T) { tcs := []struct { desc string @@ -512,13 +410,89 @@ func TestNewCommandWithEnvironmentConfig(t *testing.T) { { desc: "using the imopersonate service accounn envvar", envName: "ALLOYDB_PROXY_IMPERSONATE_SERVICE_ACCOUNT", - envValue: "sv1@developer.gserviceaccount.com,sv2@developer.gserviceaccount.com,sv3@developer.gserviceaccount.com", + envValue: "sv1@developer.gserviceaccount.com", + want: withDefaults(&proxy.Config{ + ImpersonationChain: "sv1@developer.gserviceaccount.com", + }), + }, + { + desc: "using the disable traces envvar", + envName: "ALLOYDB_PROXY_DISABLE_TRACES", + envValue: "true", + want: withDefaults(&proxy.Config{ + DisableTraces: true, + }), + }, + { + desc: "using the telemetry sample rate envvar", + envName: "ALLOYDB_PROXY_TELEMETRY_SAMPLE_RATE", + envValue: "500", + want: withDefaults(&proxy.Config{ + TelemetryTracingSampleRate: 500, + }), + }, + { + desc: "using the disable metrics envvar", + envName: "ALLOYDB_PROXY_DISABLE_METRICS", + envValue: "true", + want: withDefaults(&proxy.Config{ + DisableMetrics: true, + }), + }, + { + desc: "using the telemetry project envvar", + envName: "ALLOYDB_PROXY_TELEMETRY_PROJECT", + envValue: "mycoolproject", + want: withDefaults(&proxy.Config{ + TelemetryProject: "mycoolproject", + }), + }, + { + desc: "using the telemetry prefix envvar", + envName: "ALLOYDB_PROXY_TELEMETRY_PREFIX", + envValue: "myprefix", + want: withDefaults(&proxy.Config{ + TelemetryPrefix: "myprefix", + }), + }, + { + desc: "using the prometheus envvar", + envName: "ALLOYDB_PROXY_PROMETHEUS", + envValue: "true", + want: withDefaults(&proxy.Config{ + Prometheus: true, + }), + }, + { + desc: "using the prometheus namespace envvar", + envName: "ALLOYDB_PROXY_PROMETHEUS_NAMESPACE", + envValue: "myns", + want: withDefaults(&proxy.Config{ + PrometheusNamespace: "myns", + }), + }, + { + desc: "using the health check envvar", + envName: "ALLOYDB_PROXY_HEALTH_CHECK", + envValue: "true", + want: withDefaults(&proxy.Config{ + HealthCheck: true, + }), + }, + { + desc: "using the http address envvar", + envName: "ALLOYDB_PROXY_HTTP_ADDRESS", + envValue: "0.0.0.0", + want: withDefaults(&proxy.Config{ + HTTPAddress: "0.0.0.0", + }), + }, + { + desc: "using the http port envvar", + envName: "ALLOYDB_PROXY_HTTP_PORT", + envValue: "5555", want: withDefaults(&proxy.Config{ - ImpersonateTarget: "sv1@developer.gserviceaccount.com", - ImpersonateDelegates: []string{ - "sv3@developer.gserviceaccount.com", - "sv2@developer.gserviceaccount.com", - }, + HTTPPort: "5555", }), }, } diff --git a/internal/proxy/internal_test.go b/internal/proxy/internal_test.go new file mode 100644 index 00000000..8f2ff2fb --- /dev/null +++ b/internal/proxy/internal_test.go @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "testing" + "unsafe" + + "github.com/google/go-cmp/cmp" +) + +func TestClientUsesSyncAtomicAlignment(t *testing.T) { + // The sync/atomic pkg has a bug that requires the developer to guarantee + // 64-bit alignment when using 64-bit functions on 32-bit systems. + c := &Client{} //nolint:staticcheck + + if a := unsafe.Offsetof(c.connCount); a%64 != 0 { + t.Errorf("Client.connCount is not 64-bit aligned: want 0, got %v", a) + } +} + +func TestParseImpersonationChain(t *testing.T) { + tcs := []struct { + desc string + in string + wantTarget string + wantChain []string + }{ + { + desc: "when there is only a target", + in: "sv1@developer.gserviceaccount.com", + wantTarget: "sv1@developer.gserviceaccount.com", + }, + { + desc: "when there are delegates", + in: "sv1@developer.gserviceaccount.com,sv2@developer.gserviceaccount.com,sv3@developer.gserviceaccount.com", + wantTarget: "sv1@developer.gserviceaccount.com", + wantChain: []string{ + "sv3@developer.gserviceaccount.com", + "sv2@developer.gserviceaccount.com", + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + gotTarget, gotChain := parseImpersonationChain(tc.in) + if gotTarget != tc.wantTarget { + t.Fatalf("target: want = %v, got = %v", tc.wantTarget, gotTarget) + } + if !equalSlice(tc.wantChain, gotChain) { + t.Fatalf("want chain != got chain: %v", cmp.Diff(tc.wantChain, gotChain)) + } + }) + } +} + +func equalSlice[T comparable](x, y []T) bool { + if len(x) != len(y) { + return false + } + for i := 0; i < len(x); i++ { + if x[i] != y[i] { + return false + } + } + return true +} diff --git a/internal/proxy/other_test.go b/internal/proxy/other_test.go index 02957f9b..bd951076 100644 --- a/internal/proxy/other_test.go +++ b/internal/proxy/other_test.go @@ -16,19 +16,8 @@ package proxy import ( "testing" - "unsafe" ) -func TestClientUsesSyncAtomicAlignment(t *testing.T) { - // The sync/atomic pkg has a bug that requires the developer to guarantee - // 64-bit alignment when using 64-bit functions on 32-bit systems. - c := &Client{} //nolint:staticcheck - - if a := unsafe.Offsetof(c.connCount); a%64 != 0 { - t.Errorf("Client.connCount is not 64-bit aligned: want 0, got %v", a) - } -} - func TestUnixSocketDir(t *testing.T) { tcs := []struct { desc string diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index a475f10d..4a41e5cd 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -107,24 +107,71 @@ type Config struct { // regardless of any open connections. WaitOnClose time.Duration - // ImpersonateTarget is the service account to impersonate. The IAM - // principal doing the impersonation must have the - // roles/iam.serviceAccountTokenCreator role. - ImpersonateTarget string - // ImpersonateDelegates are the intermediate service accounts through which - // the impersonation is achieved. Each delegate must have the - // roles/iam.serviceAccountTokenCreator role. - ImpersonateDelegates []string + // ImpersonationChain is a comma separated list of one or more service + // accounts. The first entry in the chain is the impersonation target. Any + // additional service accounts after the target are delegates. The + // roles/iam.serviceAccountTokenCreator must be configured for each account + // that will be impersonated. + ImpersonationChain string // StructuredLogs sets all output to use JSON in the LogEntry format. // See https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry StructuredLogs bool + + // Quiet controls whether only error messages are logged. + Quiet bool + + // TelemetryProject enables sending metrics and traces to the specified project. + TelemetryProject string + // TelemetryPrefix sets a prefix for all emitted metrics. + TelemetryPrefix string + // TelemetryTracingSampleRate sets the rate at which traces are + // samples. A higher value means fewer traces. + TelemetryTracingSampleRate int + // DisableTraces disables tracing when TelemetryProject is set. + DisableTraces bool + // DisableMetrics disables metrics when TelemetryProject is set. + DisableMetrics bool + + // Prometheus enables a Prometheus endpoint served at the address and + // port specified by HTTPAddress and HTTPPort. + Prometheus bool + // PrometheusNamespace configures the namespace underwhich metrics are written. + PrometheusNamespace string + + // HealthCheck enables a health check server. It's address and port are + // specified by HTTPAddress and HTTPPort. + HealthCheck bool + + // HTTPAddress sets the address for the health check and prometheus server. + HTTPAddress string + // HTTPPort sets the port for the health check and prometheus server. + HTTPPort string + + // OtherUserAgents is a list of space separate user agents that will be + // appended to the default user agent. + OtherUserAgents string +} + +func parseImpersonationChain(chain string) (string, []string) { + accts := strings.Split(chain, ",") + target := accts[0] + // Assign delegates if the chain is more than one account. Delegation + // goes from last back towards target, e.g., With sa1,sa2,sa3, sa3 + // delegates to sa2, which impersonates the target sa1. + var delegates []string + if l := len(accts); l > 1 { + for i := l - 1; i > 0; i-- { + delegates = append(delegates, accts[i]) + } + } + return target, delegates } func credentialsOpt(c Config, l alloydb.Logger) (alloydbconn.Option, error) { // If service account impersonation is configured, set up an impersonated // credentials token source. - if c.ImpersonateTarget != "" { + if c.ImpersonationChain != "" { var iopts []option.ClientOption switch { case c.Token != "": @@ -148,11 +195,12 @@ func credentialsOpt(c Config, l alloydb.Logger) (alloydbconn.Option, error) { default: l.Infof("Impersonating service account with Application Default Credentials") } + target, delegates := parseImpersonationChain(c.ImpersonationChain) ts, err := impersonate.CredentialsTokenSource( context.Background(), impersonate.CredentialsConfig{ - TargetPrincipal: c.ImpersonateTarget, - Delegates: c.ImpersonateDelegates, + TargetPrincipal: target, + Delegates: delegates, Scopes: []string{sqladmin.SqlserviceAdminScope}, }, iopts...,