Skip to content

Commit

Permalink
feat: add concrete errors to public API (#36)
Browse files Browse the repository at this point in the history
* feat: add concrete errors to public API

This commit adds four concrete errors to the public API and ensures that
all returned errors are either instances of those concrete types OR are
wrapped by more descriptive errors.

The four error types are:

1. ClientError: for all errors that are the result of the client making
   semantic error in their request

2. ServerError: for all the (uncommon) cases where the server returns
   data that is somehow and unexpectedly invalid.

3. APIError: for any errors that occurs when interacting with the SQL
   Admin API. These errors wrap the underlying
   google.golang.org/api/googleapi.Error types.

4. DialError: for any error that occurs when attempting to connect to a
   particular SQL instance.

By providing concrete errors or wrapping concrete errors, we allow our
clients to uses the errors.As API to possibly react differently to
errors coming out of the dialer.

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
  • Loading branch information
enocom and kurtisvg committed Aug 12, 2021
1 parent 5d54ca6 commit 7441b71
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 85 deletions.
9 changes: 5 additions & 4 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"sync"
"time"

"cloud.google.com/go/cloudsqlconn/errtypes"
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
"cloud.google.com/go/cloudsqlconn/internal/trace"
"golang.org/x/net/proxy"
Expand Down Expand Up @@ -153,22 +154,22 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
if err != nil {
// refresh the instance info in case it caused the connection failure
i.ForceRefresh()
return nil, err
return nil, errtypes.NewDialError("failed to dial", i.String(), err)
}
if c, ok := conn.(*net.TCPConn); ok {
if err := c.SetKeepAlive(true); err != nil {
return nil, fmt.Errorf("failed to set keep-alive: %v", err)
return nil, errtypes.NewDialError("failed to set keep-alive", i.String(), err)
}
if err := c.SetKeepAlivePeriod(cfg.tcpKeepAlive); err != nil {
return nil, fmt.Errorf("failed to set keep-alive period: %v", err)
return nil, errtypes.NewDialError("failed to set keep-alive period", i.String(), err)
}
}
tlsConn := tls.Client(conn, tlsCfg)
if err := tlsConn.Handshake(); err != nil {
// refresh the instance info in case it caused the handshake failure
i.ForceRefresh()
_ = tlsConn.Close() // best effort close attempt
return nil, fmt.Errorf("handshake failed: %w", err)
return nil, errtypes.NewDialError("handshake failed", i.String(), err)
}
return tlsConn, nil
}
Expand Down
41 changes: 16 additions & 25 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (
"context"
"errors"
"io/ioutil"
"strings"
"testing"
"time"

"cloud.google.com/go/cloudsqlconn/errtypes"
"cloud.google.com/go/cloudsqlconn/internal/mock"
)

Expand Down Expand Up @@ -71,13 +71,6 @@ func TestDialerInstantiationErrors(t *testing.T) {
}
}

func errorContains(err error, want string) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), want)
}

func TestDialWithAdminAPIErrors(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(context.Background())
Expand All @@ -98,25 +91,24 @@ func TestDialWithAdminAPIErrors(t *testing.T) {
}
d.sqladmin = svc

// instance name is bad
_, err = d.Dial(context.Background(), "bad-instance-name")
if !errorContains(err, "invalid instance") {
t.Fatalf("expected Dial to fail with bad instance name, but it succeeded.")
var wantErr1 *errtypes.ConfigError
if !errors.As(err, &wantErr1) {
t.Fatalf("when instance name is invalid, want = %T, got = %v", wantErr1, err)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()

// context is canceled
_, err = d.Dial(ctx, "my-project:my-region:my-instance")
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected Dial to fail with canceled context, but it succeeded.")
t.Fatalf("when context is canceled, want = %T, got = %v", context.Canceled, err)
}

// failed to retrieve metadata or ephemeral cert (not registered in the mock)
_, err = d.Dial(context.Background(), "my-project:my-region:my-instance")
if !errorContains(err, "fetch metadata failed") {
t.Fatalf("expected Dial to fail with missing metadata")
var wantErr2 *errtypes.RefreshError
if !errors.As(err, &wantErr2) {
t.Fatalf("when API call fails, want = %T, got = %v", wantErr2, err)
}
}

Expand All @@ -142,24 +134,23 @@ func TestDialWithConfigurationErrors(t *testing.T) {
}
}()

// when failing to find private IP for public-only instance
_, err = d.Dial(context.Background(), "my-project:my-region:my-instance", WithPrivateIP())
if !errorContains(err, "does not have IP of type") {
t.Fatalf("expected Dial to fail with missing metadata")
var wantErr1 *errtypes.ConfigError
if !errors.As(err, &wantErr1) {
t.Fatalf("when IP type is invalid, want = %T, got = %v", wantErr1, err)
}

// when Dialing TCP socket fails (no server proxy running)
_, err = d.Dial(context.Background(), "my-project:my-region:my-instance")
if !errorContains(err, "connection refused") {
t.Fatalf("expected Dial to fail with connection error")
var wantErr2 *errtypes.DialError
if !errors.As(err, &wantErr2) {
t.Fatalf("when server proxy socket is unavailable, want = %T, got = %v", wantErr2, err)
}

stop := mock.StartServerProxy(t, inst)
defer stop()

// when TLS handshake fails
_, err = d.Dial(context.Background(), "my-project:my-region:my-instance")
if !errorContains(err, "handshake failed") {
t.Fatalf("expected Dial to fail with connection error")
if !errors.As(err, &wantErr2) {
t.Fatalf("when TLS handshake fails, want = %T, got = %v", wantErr2, err)
}
}
17 changes: 17 additions & 0 deletions errtypes/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2021 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package errtypes provides a number of concrete types which are used by the
// cloudsqlconn package.
package errtypes // import "cloud.google.com/go/cloudsqlconn/errtypes"
92 changes: 92 additions & 0 deletions errtypes/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright 2021 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errtypes

import "fmt"

type genericError struct {
Message string
ConnName string
}

func (e *genericError) Error() string {
return fmt.Sprintf("%v (connection name = %q)", e.Message, e.ConnName)
}

// NewConfigError initializes a ConfigError.
func NewConfigError(msg, cn string) *ConfigError {
return &ConfigError{
genericError: &genericError{Message: "Client error: " + msg, ConnName: cn},
}
}

// ConfigError represents an incorrect request by the user. Config errors
// usually indicate a semantic error (e.g., the instance connection name is
// malformated, the SQL instance does not support the requested IP type, etc.)
type ConfigError struct{ *genericError }

// NewRefreshError initializes a RefreshError.
func NewRefreshError(msg, cn string, err error) *RefreshError {
return &RefreshError{
genericError: &genericError{Message: msg, ConnName: cn},
Err: err,
}
}

// RefreshError means that an error occurred during the background
// refresh operation. In general, this is an unexpected error caused by
// an interaction with the API itself (e.g., missing certificates,
// invalid certificate encoding, region mismatch with the requested
// instance connection name, etc.).
type RefreshError struct {
*genericError
// Err is the underlying error and may be nil.
Err error
}

func (e *RefreshError) Error() string {
if e.Err == nil {
return fmt.Sprintf("Server error: %v", e.genericError)
}
return fmt.Sprintf("Server error: %v: %v", e.genericError, e.Err)
}

func (e *RefreshError) Unwrap() error { return e.Err }

// NewDialError initializes a DialError.
func NewDialError(msg, cn string, err error) *DialError {
return &DialError{
genericError: &genericError{Message: msg, ConnName: cn},
Err: err,
}
}

// DialError represents a problem that occurred when trying to dial a SQL
// instance (e.g., a failure to set the keep-alive property, a TLS handshake
// failure, a missing certificate, etc.)
type DialError struct {
*genericError
// Err is the underlying error and may be nil.
Err error
}

func (e *DialError) Error() string {
if e.Err == nil {
return fmt.Sprintf("Dial error: %v", e.genericError)
}
return fmt.Sprintf("Dial error: %v: %v", e.genericError, e.Err)
}

func (e *DialError) Unwrap() error { return e.Err }
70 changes: 70 additions & 0 deletions errtypes/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright 2021 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package errtypes_test

import (
"errors"
"testing"

"cloud.google.com/go/cloudsqlconn/errtypes"
)

func TestErrorFormatting(t *testing.T) {
tc := []struct {
desc string
err error
want string
}{
{
desc: "client error message",
err: errtypes.NewConfigError("error message", "proj:reg:inst"),
want: "Client error: error message (connection name = \"proj:reg:inst\")",
},
{
desc: "server error message without internal error",
err: errtypes.NewRefreshError("error message", "proj:reg:inst", nil),
want: "Server error: error message (connection name = \"proj:reg:inst\")",
},
{
desc: "server error message with internal error",
err: errtypes.NewRefreshError("error message", "proj:reg:inst", errors.New("inner-error")),
want: "Server error: error message (connection name = \"proj:reg:inst\"): inner-error",
},
{
desc: "Dial error without inner error",
err: errtypes.NewDialError(
"message",
"proj:reg:inst",
nil, // no error here
),
want: "Dial error: message (connection name = \"proj:reg:inst\")",
},
{
desc: "Dial error with inner error",
err: errtypes.NewDialError(
"message",
"proj:reg:inst",
errors.New("inner-error"),
),
want: "Dial error: message (connection name = \"proj:reg:inst\"): inner-error",
},
}

for _, c := range tc {
if got := c.err.Error(); got != c.want {
t.Errorf("%v, got = %q, want = %q", c.desc, got, c.want)
}
}
}
13 changes: 11 additions & 2 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"sync"
"time"

"cloud.google.com/go/cloudsqlconn/errtypes"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

Expand Down Expand Up @@ -54,7 +55,11 @@ func parseConnName(cn string) (connName, error) {
b := []byte(cn)
m := connNameRegex.FindSubmatch(b)
if m == nil {
return connName{}, fmt.Errorf("invalid instance connection name - expected PROJECT:REGION:ID")
err := errtypes.NewConfigError(
"invalid instance connection name, expected PROJECT:REGION:INSTANCE",
cn,
)
return connName{}, err
}

c := connName{
Expand Down Expand Up @@ -178,7 +183,11 @@ func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls
}
addr, ok := res.md.ipAddrs[ipType]
if !ok {
return "", nil, fmt.Errorf("instance '%s' does not have IP of type '%s'", i, ipType)
err := errtypes.NewConfigError(
fmt.Sprintf("instance does not have IP of type %q", ipType),
i.String(),
)
return "", nil, err
}
return addr, res.tlsCfg, nil
}
Expand Down
6 changes: 4 additions & 2 deletions internal/cloudsql/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"
"time"

"cloud.google.com/go/cloudsqlconn/errtypes"
"cloud.google.com/go/cloudsqlconn/internal/mock"
)

Expand Down Expand Up @@ -127,8 +128,9 @@ func TestConnectInfoErrors(t *testing.T) {
}

_, _, err = im.ConnectInfo(ctx, PublicIP)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("failed to retrieve connect info: %v", err)
var wantErr *errtypes.DialError
if !errors.As(err, &wantErr) {
t.Fatalf("when connect info fails, want = %T, got = %v", wantErr, err)
}

// when client asks for wrong IP address type
Expand Down
Loading

0 comments on commit 7441b71

Please sign in to comment.