diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index fd5a68fbc..61d95c7f2 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -510,19 +510,19 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * for _, inst := range conf.Instances { m, err := c.newSocketMount(ctx, conf, pc, inst) if err != nil { - for _, m := range mnts { - mErr := m.Close() - if mErr != nil { - l.Errorf("failed to close mount: %v", mErr) - } + c.logger.Errorf("[%v] Unable to mount socket: %v", inst.Name, err) + if networkType(conf, inst) == "unix" { + continue } - return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err) } l.Infof("[%s] Listening on %s", inst.Name, m.Addr()) mnts = append(mnts, m) } c.mnts = mnts + if len(mnts) == 0 { + return nil, fmt.Errorf("No active instance to mount") + } return c, nil } @@ -747,6 +747,14 @@ type socketMount struct { dialOpts []cloudsqlconn.DialOption } +func networkType(conf *Config, inst InstanceConnConfig) string { + if (conf.UnixSocket == "" && inst.UnixSocket == "" && inst.UnixSocketPath == "") || + (inst.Addr != "" || inst.Port != 0) { + return "tcp" + } + return "unix" +} + func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfig, inst InstanceConnConfig) (*socketMount, error) { var ( // network is one of "tcp" or "unix" @@ -765,8 +773,7 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi // instance) // use a TCP listener. // Otherwise, use a Unix socket. - if (conf.UnixSocket == "" && inst.UnixSocket == "" && inst.UnixSocketPath == "") || - (inst.Addr != "" || inst.Port != 0) { + if networkType(conf, inst) == "tcp" { network = "tcp" a := conf.Addr @@ -784,7 +791,6 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi version, err := c.dialer.EngineVersion(ctx, inst.Name) if err != nil { c.logger.Errorf("could not resolve version for %q: %v", inst.Name, err) - return nil, err } np = pc.nextDBPort(version) } @@ -801,6 +807,8 @@ func (c *Client) newSocketMount(ctx context.Context, conf *Config, pc *portConfi address, err = newUnixSocketMount(inst, conf.UnixSocket, strings.HasPrefix(version, "POSTGRES")) if err != nil { + c.logger.Errorf("could not mount unix socket %q for %q: %v", + conf.UnixSocket, inst.Name, err) return nil, err } } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 8d509295e..69f9f276a 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -17,6 +17,7 @@ package proxy_test import ( "context" "errors" + "fmt" "io" "net" "os" @@ -83,6 +84,8 @@ func (f *fakeDialer) EngineVersion(_ context.Context, inst string) (string, erro return "MYSQL_8_0", nil case strings.Contains(inst, "sqlserver"): return "SQLSERVER_2019_STANDARD", nil + case strings.Contains(inst, "fakeserver"): + return "", fmt.Errorf("non existing server") default: return "POSTGRES_14", nil } @@ -306,6 +309,42 @@ func TestClientInitialization(t *testing.T) { filepath.Join(testUnixSocketPathPg), }, }, + { + desc: "without TCP port or unix socket for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:fakeserver"}, + }, + }, + }, + { + desc: "with TCP port for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + {Name: "proj:region:fakeserver", Port: 50000}, + }, + }, + wantTCPAddrs: []string{ + "127.0.0.1:50000", + }, + }, + } + + unixTcs := []struct { + desc string + in *proxy.Config + }{ + { + desc: "with unix socket for non functional instance", + in: &proxy.Config{ + Instances: []proxy.InstanceConnConfig{ + { + Name: "proj:region:fakeserver", + UnixSocketPath: testUnixSocketPath, + }, + }, + }, + }, } for _, tc := range tcs { @@ -344,6 +383,15 @@ func TestClientInitialization(t *testing.T) { } }) } + + for _, tc := range unixTcs { + t.Run(tc.desc, func(t *testing.T) { + _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) + if err == nil { + t.Fatalf("want non nil error, got = %v", err) + } + }) + } } func TestClientLimitsMaxConnections(t *testing.T) {