Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for min ready instances #229

Merged
merged 1 commit into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,8 @@ func runSignalWrapper(cmd *Command) error {
notify := func() {}
if cmd.healthCheck {
needsHTTPServer = true
cmd.logger.Infof("Starting health check server at %s",
net.JoinHostPort(cmd.httpAddress, cmd.httpPort))
hc := healthcheck.NewCheck(p, cmd.logger)
mux.HandleFunc("/startup", hc.HandleStartup)
mux.HandleFunc("/readiness", hc.HandleReadiness)
Expand All @@ -708,7 +710,7 @@ func runSignalWrapper(cmd *Command) error {
// Start the HTTP server if anything requiring HTTP is specified.
if needsHTTPServer {
server := &http.Server{
Addr: fmt.Sprintf("%s:%s", cmd.httpAddress, cmd.httpPort),
Addr: net.JoinHostPort(cmd.httpAddress, cmd.httpPort),
Handler: mux,
}
// Start the HTTP server.
Expand Down
60 changes: 57 additions & 3 deletions internal/healthcheck/healthcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"sync"

"github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb"
Expand Down Expand Up @@ -68,7 +69,7 @@ var errNotStarted = errors.New("proxy is not started")
// HandleReadiness ensures the Check has been notified of successful startup,
// that the proxy has not reached maximum connections, and that all connections
// are healthy.
func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) {
func (c *Check) HandleReadiness(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -89,14 +90,67 @@ func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) {
return
}

err := c.proxy.CheckConnections(ctx)
if err != nil {
var minReady *int
q := req.URL.Query()
if v := q.Get("min-ready"); v != "" {
n, err := strconv.Atoi(v)
if err != nil {
c.logger.Errorf("[Health Check] min-ready must be a valid integer, got = %q", v)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "min-query must be a valid integer, got = %q", v)
return
}
if n <= 0 {
c.logger.Errorf("[Health Check] min-ready %q must be greater than zero", v)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, "min-query must be greater than zero", v)
return
}
minReady = &n
}

total, err := c.proxy.CheckConnections(ctx)

switch {
case minReady != nil && *minReady > total:
// When min ready is set and exceeds total instances, 400 status.
mErr := fmt.Errorf(
"min-ready (%v) must be less than or equal to the number of registered instances (%v)",
*minReady, total,
)
c.logger.Errorf("[Health Check] Readiness failed: %v", mErr)

w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(mErr.Error()))
return
case err != nil && minReady != nil:
// When there's an error and min ready is set, AND min ready instances
// are not ready, 503 status.
c.logger.Errorf("[Health Check] Readiness failed: %v", err)

mErr, ok := err.(proxy.MultiErr)
if !ok {
// If the err is not a MultiErr, just return it as is.
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
return
}

areReady := total - len(mErr)
if areReady < *minReady {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
return
}
case err != nil:
// When there's just an error without min-ready: 503 status.
c.logger.Errorf("[Health Check] Readiness failed: %v", err)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(err.Error()))
return
}

// No error cases apply, 200 status.
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
}
Expand Down
133 changes: 119 additions & 14 deletions internal/healthcheck/healthcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -71,6 +73,21 @@ func (*fakeDialer) Close() error {
return nil
}

type flakeyDialer struct {
dialCount uint64
fakeDialer
}

// Dial fails on odd calls and succeeds on even calls.
func (f *flakeyDialer) Dial(_ context.Context, _ string, _ ...alloydbconn.DialOption) (net.Conn, error) {
c := atomic.AddUint64(&f.dialCount, 1)
if c%2 == 0 {
conn, _ := net.Pipe()
return conn, nil
}
return nil, errors.New("flakey dialer fails on odd calls")
}

type errorDialer struct {
fakeDialer
}
Expand All @@ -79,13 +96,11 @@ func (*errorDialer) Dial(_ context.Context, _ string, _ ...alloydbconn.DialOptio
return nil, errors.New("errorDialer always errors")
}

func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer) *proxy.Client {
func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer, instances []proxy.InstanceConnConfig) *proxy.Client {
c := &proxy.Config{
Addr: proxyHost,
Port: proxyPort,
Instances: []proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst"},
},
Addr: proxyHost,
Port: proxyPort,
Instances: instances,
MaxConnections: maxConns,
}
p, err := proxy.NewClient(context.Background(), dialer, logger, c)
Expand All @@ -96,15 +111,22 @@ func newProxyWithParams(t *testing.T, maxConns uint64, dialer alloydb.Dialer) *p
}

func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client {
return newProxyWithParams(t, maxConns, &fakeDialer{})
return newProxyWithParams(t, maxConns, &fakeDialer{}, []proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst"},
})
}

func newTestProxyWithDialer(t *testing.T, d alloydb.Dialer) *proxy.Client {
return newProxyWithParams(t, 0, d)
return newProxyWithParams(t, 0, d, []proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst"},
})

}

func newTestProxy(t *testing.T) *proxy.Client {
return newProxyWithParams(t, 0, &fakeDialer{})
return newProxyWithParams(t, 0, &fakeDialer{}, []proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst"},
})
}

func TestHandleStartupWhenNotNotified(t *testing.T) {
Expand All @@ -117,7 +139,7 @@ func TestHandleStartupWhenNotNotified(t *testing.T) {
check := healthcheck.NewCheck(p, logger)

rec := httptest.NewRecorder()
check.HandleStartup(rec, &http.Request{})
check.HandleStartup(rec, &http.Request{URL: &url.URL{}})

// Startup is not complete because the Check has not been notified of the
// proxy's startup.
Expand All @@ -139,7 +161,7 @@ func TestHandleStartupWhenNotified(t *testing.T) {
check.NotifyStarted()

rec := httptest.NewRecorder()
check.HandleStartup(rec, &http.Request{})
check.HandleStartup(rec, &http.Request{URL: &url.URL{}})

resp := rec.Result()
if got, want := resp.StatusCode, http.StatusOK; got != want {
Expand All @@ -157,7 +179,7 @@ func TestHandleReadinessWhenNotNotified(t *testing.T) {
check := healthcheck.NewCheck(p, logger)

rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{})
check.HandleReadiness(rec, &http.Request{URL: &url.URL{}})

resp := rec.Result()
if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
Expand Down Expand Up @@ -193,13 +215,14 @@ func TestHandleReadinessForMaxConns(t *testing.T) {
waitForConnect := func(t *testing.T, wantCode int) *http.Response {
for i := 0; i < 10; i++ {
rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{})
check.HandleReadiness(rec, &http.Request{URL: &url.URL{}})
resp := rec.Result()
if resp.StatusCode == wantCode {
return resp
}
time.Sleep(time.Second)
}
t.Fatalf("failed to receive status code = %v", wantCode)
return nil
}
resp := waitForConnect(t, http.StatusServiceUnavailable)
Expand All @@ -224,7 +247,7 @@ func TestHandleReadinessWithConnectionProblems(t *testing.T) {
check.NotifyStarted()

rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{})
check.HandleReadiness(rec, &http.Request{URL: &url.URL{}})

resp := rec.Result()
if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
Expand All @@ -239,3 +262,85 @@ func TestHandleReadinessWithConnectionProblems(t *testing.T) {
t.Fatalf("want substring with = %q, got = %v", want, string(body))
}
}

func TestReadinessWithMinReady(t *testing.T) {
tcs := []struct {
desc string
minReady string
wantStatus int
dialer alloydb.Dialer
}{
{
desc: "when min ready is zero",
minReady: "0",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when min ready is less than zero",
minReady: "-1",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when only one instance must be ready",
minReady: "1",
wantStatus: http.StatusOK,
dialer: &flakeyDialer{}, // fails on first call, succeeds on second
},
{
desc: "when all instances must be ready",
minReady: "2",
wantStatus: http.StatusServiceUnavailable,
dialer: &errorDialer{},
},
{
desc: "when min ready is greater than the number of instances",
minReady: "3",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when min ready is bogus",
minReady: "bogus",
wantStatus: http.StatusBadRequest,
dialer: &fakeDialer{},
},
{
desc: "when min ready is not set",
minReady: "",
wantStatus: http.StatusOK,
dialer: &fakeDialer{},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
p := newProxyWithParams(t, 0,
tc.dialer,
[]proxy.InstanceConnConfig{
{Name: "projects/proj/locations/region/clusters/clust/instances/inst-1"},
{Name: "projects/proj/locations/region/clusters/clust/instances/inst-2"},
},
)
defer func() {
if err := p.Close(); err != nil {
t.Logf("failed to close proxy client: %v", err)
}
}()

check := healthcheck.NewCheck(p, logger)
check.NotifyStarted()
u, err := url.Parse(fmt.Sprintf("/readiness?min-ready=%s", tc.minReady))
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
check.HandleReadiness(rec, &http.Request{URL: u})

resp := rec.Result()
if got, want := resp.StatusCode, tc.wantStatus; got != want {
t.Fatalf("want = %v, got = %v", want, got)
}
})
}
}
6 changes: 5 additions & 1 deletion internal/proxy/fuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,13 @@ func TestFUSECheckConnections(t *testing.T) {
conn := tryDialUnix(t, postgresSocketPath(fuseDir, "proj.region.cluster.instance"))
defer conn.Close()

if err := c.CheckConnections(context.Background()); err != nil {
n, err := c.CheckConnections(context.Background())
if err != nil {
t.Fatalf("c.CheckConnections(): %v", err)
}
if want, got := 1, n; want != got {
t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got)
}

// verify the dialer was invoked twice, once for connect, once for check
// connection
Expand Down
18 changes: 11 additions & 7 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,9 @@ func NewClient(ctx context.Context, d alloydb.Dialer, l alloydb.Logger, conf *Co
return c, nil
}

// CheckConnections dials each registered instance and reports any errors that
// may have occurred.
func (c *Client) CheckConnections(ctx context.Context) error {
// CheckConnections dials each registered instance and reports the number of
// connections checked and any errors that may have occurred.
func (c *Client) CheckConnections(ctx context.Context) (int, error) {
var (
wg sync.WaitGroup
errCh = make(chan error, len(c.mnts))
Expand All @@ -394,25 +394,29 @@ func (c *Client) CheckConnections(ctx context.Context) error {
}
cErr := conn.Close()
if cErr != nil {
errCh <- fmt.Errorf("%v: %v", inst, cErr)
c.logger.Errorf(
"connection check failed to close connection for %v: %v",
inst, cErr,
)
}
}(m.inst)
}
wg.Wait()

var mErr MultiErr
for i := 0; i < len(c.mnts); i++ {
for i := 0; i < len(mnts); i++ {
select {
case err := <-errCh:
mErr = append(mErr, err)
default:
continue
}
}
mLen := len(mnts)
if len(mErr) > 0 {
return mErr
return mLen, mErr
}
return nil
return mLen, nil
}

// ConnCount returns the number of open connections and the maximum allowed
Expand Down
Loading