Skip to content

Commit

Permalink
fix: don't depend on downstream in readiness check (#641)
Browse files Browse the repository at this point in the history
The readiness probe should not depend on downstream connections. Doing
so can cause unwanted downtime. See [1]. This commit removes all dialing
of downstream database servers as a result.

In addition, after the Proxy receives a SIGTERM or SIGINT, the readiness
check will now report an unhealthy status to ensure it is removed from
circulation before shutdown completes.

[1]: https://github.com/zegl/kube-score/blob/master/README_PROBES.md#readinessprobe
  • Loading branch information
enocom committed May 2, 2024
1 parent eb4435b commit 3a7c789
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 220 deletions.
10 changes: 7 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -1066,8 +1066,10 @@ func runSignalWrapper(cmd *Command) (err error) {
var (
needsHTTPServer bool
mux = http.NewServeMux()
notify = func() {}
notifyStarted = func() {}
notifyStopped = func() {}
)

if cmd.conf.Prometheus {
needsHTTPServer = true
e, err := prometheus.NewExporter(prometheus.Options{
Expand All @@ -1087,8 +1089,10 @@ func runSignalWrapper(cmd *Command) (err error) {
mux.HandleFunc("/startup", hc.HandleStartup)
mux.HandleFunc("/readiness", hc.HandleReadiness)
mux.HandleFunc("/liveness", hc.HandleLiveness)
notify = hc.NotifyStarted
notifyStarted = hc.NotifyStarted
notifyStopped = hc.NotifyStopped
}
defer notifyStopped()
// Start the HTTP server if anything requiring HTTP is specified.
if needsHTTPServer {
go startHTTPServer(
Expand Down Expand Up @@ -1131,7 +1135,7 @@ func runSignalWrapper(cmd *Command) (err error) {
)
}

go func() { shutdownCh <- p.Serve(ctx, notify) }()
go func() { shutdownCh <- p.Serve(ctx, notifyStarted) }()

err = <-shutdownCh
switch {
Expand Down
111 changes: 35 additions & 76 deletions internal/healthcheck/healthcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package healthcheck

import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"sync"

"github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb"
Expand All @@ -31,25 +29,35 @@ import (
// Check provides HTTP handlers for use as healthchecks typically in a
// Kubernetes context.
type Check struct {
once *sync.Once
started chan struct{}
proxy *proxy.Client
logger alloydb.Logger
startedOnce *sync.Once
started chan struct{}
stoppedOnce *sync.Once
stopped chan struct{}
proxy *proxy.Client
logger alloydb.Logger
}

// NewCheck is the initializer for Check.
func NewCheck(p *proxy.Client, l alloydb.Logger) *Check {
return &Check{
once: &sync.Once{},
started: make(chan struct{}),
proxy: p,
logger: l,
startedOnce: &sync.Once{},
started: make(chan struct{}),
stoppedOnce: &sync.Once{},
stopped: make(chan struct{}),
proxy: p,
logger: l,
}
}

// NotifyStarted notifies the check that the proxy has started up successfully.
func (c *Check) NotifyStarted() {
c.once.Do(func() { close(c.started) })
c.startedOnce.Do(func() { close(c.started) })
}

// NotifyStopped notifies the check that the proxy has initiated its shutdown
// sequence.
func (c *Check) NotifyStopped() {
c.stoppedOnce.Do(func() { close(c.stopped) })
}

// HandleStartup reports whether the Check has been notified of startup.
Expand All @@ -64,86 +72,37 @@ func (c *Check) HandleStartup(w http.ResponseWriter, _ *http.Request) {
}
}

var errNotStarted = errors.New("proxy is not started")
var (
errNotStarted = errors.New("proxy is not started")
errStopped = errors.New("proxy has stopped")
)

// HandleReadiness ensures the Check has been notified of successful startup,
// that the proxy has not reached maximum connections, and that all connections
// are healthy.
func (c *Check) HandleReadiness(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()

// that the proxy has not reached maximum connections, and that the Proxy has
// not started shutting down.
func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) {
select {
case <-c.started:
// Proxy has started.
default:
c.logger.Errorf("[Health Check] Readiness failed: %v", errNotStarted)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(errNotStarted.Error()))
return
}

if open, max := c.proxy.ConnCount(); max > 0 && open == max {
err := fmt.Errorf("max connections reached (open = %v, max = %v)", open, max)
c.logger.Errorf("[Health Check] Readiness failed: %v", err)
select {
case <-c.stopped:
c.logger.Errorf("[Health Check] Readiness failed: %v", errStopped)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
w.Write([]byte(errStopped.Error()))
return
default:
// Proxy has not stopped.
}

var minReady *int
q := req.URL.Query()
if v := q.Get("min-ready"); v != "" {
n, err := strconv.Atoi(v)
if err != nil {
c.logger.Errorf("[Health Check] min-ready must be a valid integer, got = %q", v)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "min-query must be a valid integer, got = %q", v)
return
}
if n <= 0 {
c.logger.Errorf("[Health Check] min-ready %q must be greater than zero", v)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, "min-query must be greater than zero", v)
return
}
minReady = &n
}

total, err := c.proxy.CheckConnections(ctx)

switch {
case minReady != nil && *minReady > total:
// When min ready is set and exceeds total instances, 400 status.
mErr := fmt.Errorf(
"min-ready (%v) must be less than or equal to the number of registered instances (%v)",
*minReady, total,
)
c.logger.Errorf("[Health Check] Readiness failed: %v", mErr)

w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(mErr.Error()))
return
case err != nil && minReady != nil:
// When there's an error and min ready is set, AND min ready instances
// are not ready, 503 status.
c.logger.Errorf("[Health Check] Readiness failed: %v", err)

mErr, ok := err.(proxy.MultiErr)
if !ok {
// If the err is not a MultiErr, just return it as is.
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
return
}

areReady := total - len(mErr)
if areReady < *minReady {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
return
}
case err != nil:
// When there's just an error without min-ready: 503 status.
if open, max := c.proxy.ConnCount(); max > 0 && open == max {
err := fmt.Errorf("max connections reached (open = %v, max = %v)", open, max)
c.logger.Errorf("[Health Check] Readiness failed: %v", err)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
Expand Down
162 changes: 21 additions & 141 deletions internal/healthcheck/healthcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package healthcheck_test

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -25,7 +24,6 @@ import (
"net/url"
"os"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -73,29 +71,6 @@ func (*fakeDialer) Close() error {
return nil
}

type flakeyDialer struct {
dialCount uint64
fakeDialer
}

// Dial fails on odd calls and succeeds on even calls.
func (f *flakeyDialer) Dial(_ context.Context, _ string, _ ...alloydbconn.DialOption) (net.Conn, error) {
c := atomic.AddUint64(&f.dialCount, 1)
if c%2 == 0 {
conn, _ := net.Pipe()
return conn, nil
}
return nil, errors.New("flakey dialer fails on odd calls")
}

type errorDialer struct {
fakeDialer
}

func (*errorDialer) Dial(_ context.Context, _ string, _ ...alloydbconn.DialOption) (net.Conn, error) {
return nil, errors.New("errorDialer always errors")
}

func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer, instances []proxy.InstanceConnConfig) *proxy.Client {
c := &proxy.Config{
Addr: proxyHost,
Expand All @@ -116,13 +91,6 @@ func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client {
})
}

func newTestProxyWithDialer(t *testing.T, d alloydb.Dialer) *proxy.Client {
return newProxyWithParams(t, 0, d, []proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst"},
})

}

func newTestProxy(t *testing.T) *proxy.Client {
return newProxyWithParams(t, 0, &fakeDialer{}, []proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst"},
Expand Down Expand Up @@ -187,6 +155,27 @@ func TestHandleReadinessWhenNotNotified(t *testing.T) {
}
}

func TestHandleReadinessWhenStopped(t *testing.T) {
p := newTestProxy(t)
defer func() {
if err := p.Close(); err != nil {
t.Logf("failed to close proxy client: %v", err)
}
}()
check := healthcheck.NewCheck(p, logger)

check.NotifyStarted() // The Proxy has started.
check.NotifyStopped() // And now the Proxy is shutting down.

rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{URL: &url.URL{}})

resp := rec.Result()
if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
t.Fatalf("want = %v, got = %v", want, got)
}
}

func TestHandleReadinessForMaxConns(t *testing.T) {
p := newTestProxyWithMaxConns(t, 1)
defer func() {
Expand Down Expand Up @@ -235,112 +224,3 @@ func TestHandleReadinessForMaxConns(t *testing.T) {
t.Fatalf("want max connections error, got = %v", string(body))
}
}

func TestHandleReadinessWithConnectionProblems(t *testing.T) {
p := newTestProxyWithDialer(t, &errorDialer{}) // error dialer will error on dial
defer func() {
if err := p.Close(); err != nil {
t.Logf("failed to close proxy client: %v", err)
}
}()
check := healthcheck.NewCheck(p, logger)
check.NotifyStarted()

rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{URL: &url.URL{}})

resp := rec.Result()
if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
t.Fatalf("want = %v, got = %v", want, got)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
if want := "errorDialer"; !strings.Contains(string(body), want) {
t.Fatalf("want substring with = %q, got = %v", want, string(body))
}
}

func TestReadinessWithMinReady(t *testing.T) {
tcs := []struct {
desc string
minReady string
wantStatus int
dialer alloydb.Dialer
}{
{
desc: "when min ready is zero",
minReady: "0",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when min ready is less than zero",
minReady: "-1",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when only one instance must be ready",
minReady: "1",
wantStatus: http.StatusOK,
dialer: &flakeyDialer{}, // fails on first call, succeeds on second
},
{
desc: "when all instances must be ready",
minReady: "2",
wantStatus: http.StatusServiceUnavailable,
dialer: &errorDialer{},
},
{
desc: "when min ready is greater than the number of instances",
minReady: "3",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when min ready is bogus",
minReady: "bogus",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when min ready is not set",
minReady: "",
wantStatus: http.StatusOK,
dialer: &fakeDialer{},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
p := newProxyWithParams(t, 0,
tc.dialer,
[]proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst-1"},
{Name: "projects/proj/locations/region/clusters/clust/instances/inst-2"},
},
)
defer func() {
if err := p.Close(); err != nil {
t.Logf("failed to close proxy client: %v", err)
}
}()

check := healthcheck.NewCheck(p, logger)
check.NotifyStarted()
u, err := url.Parse(fmt.Sprintf("/readiness?min-ready=%s", tc.minReady))
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{URL: u})

resp := rec.Result()
if got, want := resp.StatusCode, tc.wantStatus; got != want {
t.Fatalf("want = %v, got = %v", want, got)
}
})
}
}

0 comments on commit 3a7c789

Please sign in to comment.