diff --git a/cmd/root.go b/cmd/root.go index 58c4fbf5..2ee7996c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -698,6 +698,8 @@ func runSignalWrapper(cmd *Command) error { notify := func() {} if cmd.healthCheck { needsHTTPServer = true + cmd.logger.Infof("Starting health check server at %s", + net.JoinHostPort(cmd.httpAddress, cmd.httpPort)) hc := healthcheck.NewCheck(p, cmd.logger) mux.HandleFunc("/startup", hc.HandleStartup) mux.HandleFunc("/readiness", hc.HandleReadiness) @@ -708,7 +710,7 @@ func runSignalWrapper(cmd *Command) error { // Start the HTTP server if anything requiring HTTP is specified. if needsHTTPServer { server := &http.Server{ - Addr: fmt.Sprintf("%s:%s", cmd.httpAddress, cmd.httpPort), + Addr: net.JoinHostPort(cmd.httpAddress, cmd.httpPort), Handler: mux, } // Start the HTTP server. diff --git a/internal/healthcheck/healthcheck.go b/internal/healthcheck/healthcheck.go index ecb15d7f..1fd7818d 100644 --- a/internal/healthcheck/healthcheck.go +++ b/internal/healthcheck/healthcheck.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "sync" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" @@ -68,7 +69,7 @@ var errNotStarted = errors.New("proxy is not started") // HandleReadiness ensures the Check has been notified of successful startup, // that the proxy has not reached maximum connections, and that all connections // are healthy. -func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) { +func (c *Check) HandleReadiness(w http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -89,14 +90,67 @@ func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) { return } - err := c.proxy.CheckConnections(ctx) - if err != nil { + var minReady *int + q := req.URL.Query() + if v := q.Get("min-ready"); v != "" { + n, err := strconv.Atoi(v) + if err != nil { + c.logger.Errorf("[Health Check] min-ready must be a valid integer, got = %q", v) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "min-query must be a valid integer, got = %q", v) + return + } + if n <= 0 { + c.logger.Errorf("[Health Check] min-ready %q must be greater than zero", v) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, "min-query must be greater than zero", v) + return + } + minReady = &n + } + + total, err := c.proxy.CheckConnections(ctx) + + switch { + case minReady != nil && *minReady > total: + // When min ready is set and exceeds total instances, 400 status. + mErr := fmt.Errorf( + "min-ready (%v) must be less than or equal to the number of registered instances (%v)", + *minReady, total, + ) + c.logger.Errorf("[Health Check] Readiness failed: %v", mErr) + + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(mErr.Error())) + return + case err != nil && minReady != nil: + // When there's an error and min ready is set, AND min ready instances + // are not ready, 503 status. + c.logger.Errorf("[Health Check] Readiness failed: %v", err) + + mErr, ok := err.(proxy.MultiErr) + if !ok { + // If the err is not a MultiErr, just return it as is. + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(err.Error())) + return + } + + areReady := total - len(mErr) + if areReady < *minReady { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(err.Error())) + return + } + case err != nil: + // When there's just an error without min-ready: 503 status. c.logger.Errorf("[Health Check] Readiness failed: %v", err) w.WriteHeader(http.StatusServiceUnavailable) w.Write([]byte(err.Error())) return } + // No error cases apply, 200 status. w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) } diff --git a/internal/healthcheck/healthcheck_test.go b/internal/healthcheck/healthcheck_test.go index adaba488..b6c174b0 100644 --- a/internal/healthcheck/healthcheck_test.go +++ b/internal/healthcheck/healthcheck_test.go @@ -22,8 +22,10 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "strings" + "sync/atomic" "testing" "time" @@ -71,6 +73,21 @@ func (*fakeDialer) Close() error { return nil } +type flakeyDialer struct { + dialCount uint64 + fakeDialer +} + +// Dial fails on odd calls and succeeds on even calls. +func (f *flakeyDialer) Dial(_ context.Context, _ string, _ ...alloydbconn.DialOption) (net.Conn, error) { + c := atomic.AddUint64(&f.dialCount, 1) + if c%2 == 0 { + conn, _ := net.Pipe() + return conn, nil + } + return nil, errors.New("flakey dialer fails on odd calls") +} + type errorDialer struct { fakeDialer } @@ -79,13 +96,11 @@ func (*errorDialer) Dial(_ context.Context, _ string, _ ...alloydbconn.DialOptio return nil, errors.New("errorDialer always errors") } -func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer) *proxy.Client { +func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer, instances []proxy.InstanceConnConfig) *proxy.Client { c := &proxy.Config{ - Addr: proxyHost, - Port: proxyPort, - Instances: []proxy.InstanceConnConfig{ - {Name: "projects/proj/locations/region/clusters/clust/instances/inst"}, - }, + Addr: proxyHost, + Port: proxyPort, + Instances: instances, MaxConnections: maxConns, } p, err := proxy.NewClient(context.Background(), dialer, logger, c) @@ -96,15 +111,22 @@ func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer) *p } func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client { - return newProxyWithParams(t, maxConns, &fakeDialer{}) + return newProxyWithParams(t, maxConns, &fakeDialer{}, []proxy.InstanceConnConfig{ + {Name: "projects/proj/locations/region/clusters/clust/instances/inst"}, + }) } func newTestProxyWithDialer(t *testing.T, d alloydb.Dialer) *proxy.Client { - return newProxyWithParams(t, 0, d) + return newProxyWithParams(t, 0, d, []proxy.InstanceConnConfig{ + {Name: "projects/proj/locations/region/clusters/clust/instances/inst"}, + }) + } func newTestProxy(t *testing.T) *proxy.Client { - return newProxyWithParams(t, 0, &fakeDialer{}) + return newProxyWithParams(t, 0, &fakeDialer{}, []proxy.InstanceConnConfig{ + {Name: "projects/proj/locations/region/clusters/clust/instances/inst"}, + }) } func TestHandleStartupWhenNotNotified(t *testing.T) { @@ -117,7 +139,7 @@ func TestHandleStartupWhenNotNotified(t *testing.T) { check := healthcheck.NewCheck(p, logger) rec := httptest.NewRecorder() - check.HandleStartup(rec, &http.Request{}) + check.HandleStartup(rec, &http.Request{URL: &url.URL{}}) // Startup is not complete because the Check has not been notified of the // proxy's startup. @@ -139,7 +161,7 @@ func TestHandleStartupWhenNotified(t *testing.T) { check.NotifyStarted() rec := httptest.NewRecorder() - check.HandleStartup(rec, &http.Request{}) + check.HandleStartup(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { @@ -157,7 +179,7 @@ func TestHandleReadinessWhenNotNotified(t *testing.T) { check := healthcheck.NewCheck(p, logger) rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{}) + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { @@ -193,13 +215,14 @@ func TestHandleReadinessForMaxConns(t *testing.T) { waitForConnect := func(t *testing.T, wantCode int) *http.Response { for i := 0; i < 10; i++ { rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{}) + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if resp.StatusCode == wantCode { return resp } time.Sleep(time.Second) } + t.Fatalf("failed to receive status code = %v", wantCode) return nil } resp := waitForConnect(t, http.StatusServiceUnavailable) @@ -224,7 +247,7 @@ func TestHandleReadinessWithConnectionProblems(t *testing.T) { check.NotifyStarted() rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{}) + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { @@ -239,3 +262,85 @@ func TestHandleReadinessWithConnectionProblems(t *testing.T) { t.Fatalf("want substring with = %q, got = %v", want, string(body)) } } + +func TestReadinessWithMinReady(t *testing.T) { + tcs := []struct { + desc string + minReady string + wantStatus int + dialer alloydb.Dialer + }{ + { + desc: "when min ready is zero", + minReady: "0", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when min ready is less than zero", + minReady: "-1", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when only one instance must be ready", + minReady: "1", + wantStatus: http.StatusOK, + dialer: &flakeyDialer{}, // fails on first call, succeeds on second + }, + { + desc: "when all instances must be ready", + minReady: "2", + wantStatus: http.StatusServiceUnavailable, + dialer: &errorDialer{}, + }, + { + desc: "when min ready is greater than the number of instances", + minReady: "3", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when min ready is bogus", + minReady: "bogus", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when min ready is not set", + minReady: "", + wantStatus: http.StatusOK, + dialer: &fakeDialer{}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + p := newProxyWithParams(t, 0, + tc.dialer, + []proxy.InstanceConnConfig{ + {Name: "projects/proj/locations/region/clusters/clust/instances/inst-1"}, + {Name: "projects/proj/locations/region/clusters/clust/instances/inst-2"}, + }, + ) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + + check := healthcheck.NewCheck(p, logger) + check.NotifyStarted() + u, err := url.Parse(fmt.Sprintf("/readiness?min-ready=%s", tc.minReady)) + if err != nil { + t.Fatal(err) + } + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{URL: u}) + + resp := rec.Result() + if got, want := resp.StatusCode, tc.wantStatus; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } + }) + } +} diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index de367f2e..2a0b5574 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -244,9 +244,13 @@ func TestFUSECheckConnections(t *testing.T) { conn := tryDialUnix(t, postgresSocketPath(fuseDir, "proj.region.cluster.instance")) defer conn.Close() - if err := c.CheckConnections(context.Background()); err != nil { + n, err := c.CheckConnections(context.Background()) + if err != nil { t.Fatalf("c.CheckConnections(): %v", err) } + if want, got := 1, n; want != got { + t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got) + } // verify the dialer was invoked twice, once for connect, once for check // connection diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 00bbc288..a475f10d 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -371,9 +371,9 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co return c, nil } -// CheckConnections dials each registered instance and reports any errors that -// may have occurred. -func (c *Client) CheckConnections(ctx context.Context) error { +// CheckConnections dials each registered instance and reports the number of +// connections checked and any errors that may have occurred. +func (c *Client) CheckConnections(ctx context.Context) (int, error) { var ( wg sync.WaitGroup errCh = make(chan error, len(c.mnts)) @@ -394,14 +394,17 @@ func (c *Client) CheckConnections(ctx context.Context) error { } cErr := conn.Close() if cErr != nil { - errCh <- fmt.Errorf("%v: %v", inst, cErr) + c.logger.Errorf( + "connection check failed to close connection for %v: %v", + inst, cErr, + ) } }(m.inst) } wg.Wait() var mErr MultiErr - for i := 0; i < len(c.mnts); i++ { + for i := 0; i < len(mnts); i++ { select { case err := <-errCh: mErr = append(mErr, err) @@ -409,10 +412,11 @@ func (c *Client) CheckConnections(ctx context.Context) error { continue } } + mLen := len(mnts) if len(mErr) > 0 { - return mErr + return mLen, mErr } - return nil + return mLen, nil } // ConnCount returns the number of open connections and the maximum allowed diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 1f93e2d7..802a4148 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -610,9 +610,13 @@ func TestCheckConnections(t *testing.T) { defer c.Close() go c.Serve(context.Background(), func() {}) - if err = c.CheckConnections(context.Background()); err != nil { + n, err := c.CheckConnections(context.Background()) + if err != nil { t.Fatalf("CheckConnections failed: %v", err) } + if want, got := len(in.Instances), n; want != got { + t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got) + } if want, got := 1, d.dialAttempts(); want != got { t.Fatalf("dial attempts: want = %v, got = %v", want, got) @@ -634,8 +638,11 @@ func TestCheckConnections(t *testing.T) { defer c.Close() go c.Serve(context.Background(), func() {}) - err = c.CheckConnections(context.Background()) + n, err = c.CheckConnections(context.Background()) if err == nil { t.Fatal("CheckConnections should have failed, but did not") } + if want, got := len(in.Instances), n; want != got { + t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got) + } }