diff --git a/.envrc.example b/.envrc.example index 889725bfc..8bf5e9e3d 100644 --- a/.envrc.example +++ b/.envrc.example @@ -15,3 +15,5 @@ export SQLSERVER_CONNECTION_NAME="project:region:instance" export SQLSERVER_USER="sqlserver-user" export SQLSERVER_PASS="sqlserver-password" export SQLSERVER_DB="sqlserver-db-name" + +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json diff --git a/.gitignore b/.gitignore index 8bd39d275..9c416baca 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,7 @@ # Compiled binary /cmd/cloud_sql_proxy/cloud_sql_proxy /cloud_sql_proxy +# v2 binary +/cloudsql-proxy + +/key.json diff --git a/cmd/root.go b/cmd/root.go index 8f9e697fd..dede5b78c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -92,7 +92,7 @@ connecting to Cloud SQL instances. It listens on a local port and forwards conne to your instance's IP address, providing a secure connection without having to manage any client SSL certificates.`, Args: func(cmd *cobra.Command, args []string) error { - err := parseConfig(c.conf, args) + err := parseConfig(cmd, c.conf, args) if err != nil { return err } @@ -108,6 +108,8 @@ any client SSL certificates.`, // Global-only flags cmd.PersistentFlags().StringVarP(&c.conf.Token, "token", "t", "", "Bearer token used for authorization.") + cmd.PersistentFlags().StringVarP(&c.conf.CredentialsFile, "credentials-file", "c", "", + "Path to a service account key to use for authentication.") // Global and per instance flags cmd.PersistentFlags().StringVarP(&c.conf.Addr, "address", "a", "127.0.0.1", @@ -119,7 +121,7 @@ any client SSL certificates.`, return c } -func parseConfig(conf *proxy.Config, args []string) error { +func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { // If no instance connection names were provided, error. if len(args) == 0 { return newBadCommandError("missing instance_connection_name (e.g., project:region:instance)") @@ -129,6 +131,20 @@ func parseConfig(conf *proxy.Config, args []string) error { return newBadCommandError(fmt.Sprintf("not a valid IP address: %q", conf.Addr)) } + // If both token and credentials file were set, error. + if conf.Token != "" && conf.CredentialsFile != "" { + return newBadCommandError("Cannot specify --token and --credentials-file flags at the same time") + } + + switch { + case conf.Token != "": + cmd.Printf("Authorizing with the -token flag\n") + case conf.CredentialsFile != "": + cmd.Printf("Authorizing with the credentials file at %q\n", conf.CredentialsFile) + default: + cmd.Printf("Authorizing with Application Default Credentials") + } + var ics []proxy.InstanceConnConfig for _, a := range args { // Assume no query params initially @@ -211,8 +227,8 @@ func runSignalWrapper(cmd *Command) error { // Otherwise, initialize a new one. d := cmd.conf.Dialer if d == nil { - var err error opts := append(cmd.conf.DialerOpts(), cloudsqlconn.WithUserAgent(userAgent)) + var err error d, err = cloudsqlconn.NewDialer(ctx, opts...) if err != nil { shutdownCh <- fmt.Errorf("error initializing dialer: %v", err) diff --git a/cmd/root_test.go b/cmd/root_test.go index 93539879c..12457d4cf 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -119,6 +119,20 @@ func TestNewCommandArguments(t *testing.T) { Token: "MYCOOLTOKEN", }), }, + { + desc: "using the credentiale file flag", + args: []string{"--credentials-file", "/path/to/file", "proj:region:inst"}, + want: withDefaults(&proxy.Config{ + CredentialsFile: "/path/to/file", + }), + }, + { + desc: "using the (short) credentiale file flag", + args: []string{"-c", "/path/to/file", "proj:region:inst"}, + want: withDefaults(&proxy.Config{ + CredentialsFile: "/path/to/file", + }), + }, } for _, tc := range tcs { @@ -186,6 +200,12 @@ func TestNewCommandWithErrors(t *testing.T) { desc: "when the port query param is not a number", args: []string{"proj:region:inst?port=hi"}, }, + { + desc: "when both token and credentials file is set", + args: []string{ + "--token", "my-token", + "--credentials-file", "/path/to/file", "proj:region:inst"}, + }, } for _, tc := range tcs { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 9ea30bf88..93b723092 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -45,6 +45,9 @@ type Config struct { // Token is the Bearer token used for authorization. Token string + // CredentialsFile is the path to a service account key. + CredentialsFile string + // Addr is the address on which to bind all instances. Addr string @@ -61,18 +64,17 @@ type Config struct { Dialer cloudsql.Dialer } -// NewConfig initializes a Config struct using the default database engine -// ports. -func NewConfig() *Config { - return &Config{} -} - func (c *Config) DialerOpts() []cloudsqlconn.Option { var opts []cloudsqlconn.Option - if c.Token != "" { + switch { + case c.Token != "": opts = append(opts, cloudsqlconn.WithTokenSource( oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}), )) + case c.CredentialsFile != "": + opts = append(opts, cloudsqlconn.WithCredentialsFile( + c.CredentialsFile, + )) } return opts } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2be19803b..34a490dd2 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -167,30 +167,3 @@ func TestClientInitialization(t *testing.T) { }) } } - -func TestConfigDialerOpts(t *testing.T) { - tcs := []struct { - desc string - config proxy.Config - wantLen int - }{ - { - desc: "when there are no options", - config: proxy.Config{}, - wantLen: 0, - }, - { - desc: "when a token is present", - config: proxy.Config{Token: "my-token"}, - wantLen: 1, - }, - } - - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - if got := tc.config.DialerOpts(); tc.wantLen != len(got) { - t.Errorf("want len = %v, got = %v", tc.wantLen, len(got)) - } - }) - } -} diff --git a/tests/alldb_test.go b/tests/alldb_test.go index a65404d9a..5d90d4d48 100644 --- a/tests/alldb_test.go +++ b/tests/alldb_test.go @@ -20,6 +20,7 @@ import ( "fmt" "net/http" "testing" + "time" ) // requireAllVars skips the given test if at least one environment variable is undefined. @@ -43,13 +44,17 @@ func TestMultiInstanceDial(t *testing.T) { t.Skip("skipping Health Check integration tests") } requireAllVars(t) - ctx := context.Background() - - var args []string - args = append(args, fmt.Sprintf("-instances=%s=tcp:%d,%s=tcp:%d,%s=tcp:%d", *mysqlConnName, mysqlPort, *postgresConnName, postgresPort, *sqlserverConnName, sqlserverPort)) - args = append(args, "-use_http_health_check") + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() // Start the proxy. + args := []string{ + // This test doesn't care what the instance port is, so use "0" which + // means, let the runtime pick a random port. + fmt.Sprintf("-instances=%s=tcp:0,%s=tcp:0,%s=tcp:0", + *mysqlConnName, *postgresConnName, *sqlserverConnName), + "-use_http_health_check", + } p, err := StartProxy(ctx, args...) if err != nil { t.Fatalf("unable to start proxy: %v", err) diff --git a/testsV2/connection_test.go b/testsV2/connection_test.go index 52e41eaab..63e8a8bdb 100644 --- a/testsV2/connection_test.go +++ b/testsV2/connection_test.go @@ -17,15 +17,45 @@ package tests import ( "context" "database/sql" + "os" "testing" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "google.golang.org/api/sqladmin/v1" ) -// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test. -func proxyConnTest(t *testing.T, connName, driver, dsn string, port int, dir string) { - ctx := context.Background() +const connTestTimeout = time.Minute - args := []string{connName} +// removeAuthEnvVar retrieves an OAuth2 token and a path to a service account key +// and then unsets GOOGLE_APPLICATION_CREDENTIALS. It returns a cleanup function +// that restores the original setup. +func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) { + ts, err := google.DefaultTokenSource(context.Background(), sqladmin.SqlserviceAdminScope) + if err != nil { + t.Errorf("failed to resolve token source: %v", err) + } + tok, err := ts.Token() + if err != nil { + t.Errorf("failed to get token: %v", err) + } + path, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS") + if !ok { + t.Fatalf("GOOGLE_APPLICATION_CREDENTIALS was not set in the environment") + } + if err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); err != nil { + t.Fatalf("failed to unset GOOGLE_APPLICATION_CREDENTIALS") + } + return tok, path, func() { + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", path) + } +} +// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test. +func proxyConnTest(t *testing.T, args []string, driver, dsn string) { + ctx, cancel := context.WithTimeout(context.Background(), connTestTimeout) + defer cancel() // Start the proxy p, err := StartProxy(ctx, args...) if err != nil { diff --git a/testsV2/mysql_test.go b/testsV2/mysql_test.go index 88b8c44be..2246b2861 100644 --- a/testsV2/mysql_test.go +++ b/testsV2/mysql_test.go @@ -27,12 +27,10 @@ var ( mysqlConnName = flag.String("mysql_conn_name", os.Getenv("MYSQL_CONNECTION_NAME"), "Cloud SQL MYSQL instance connection name, in the form of 'project:region:instance'.") mysqlUser = flag.String("mysql_user", os.Getenv("MYSQL_USER"), "Name of database user.") mysqlPass = flag.String("mysql_pass", os.Getenv("MYSQL_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") - mysqlDb = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.") - - mysqlPort = 3306 + mysqlDB = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.") ) -func requireMysqlVars(t *testing.T) { +func requireMySQLVars(t *testing.T) { switch "" { case *mysqlConnName: t.Fatal("'mysql_conn_name' not set") @@ -40,23 +38,65 @@ func requireMysqlVars(t *testing.T) { t.Fatal("'mysql_user' not set") case *mysqlPass: t.Fatal("'mysql_pass' not set") - case *mysqlDb: + case *mysqlDB: t.Fatal("'mysql_db' not set") } } -func TestMysqlTcp(t *testing.T) { +func TestMySQLTCP(t *testing.T) { + if testing.Short() { + t.Skip("skipping MySQL integration tests") + } + requireMySQLVars(t) + cfg := mysql.Config{ + User: *mysqlUser, + Passwd: *mysqlPass, + DBName: *mysqlDB, + AllowNativePasswords: true, + Addr: "127.0.0.1:3306", + Net: "tcp", + } + proxyConnTest(t, []string{*mysqlConnName}, "mysql", cfg.FormatDSN()) +} + +func TestMySQLAuthWithToken(t *testing.T) { if testing.Short() { t.Skip("skipping MySQL integration tests") } - requireMysqlVars(t) + requireMySQLVars(t) + tok, _, cleanup := removeAuthEnvVar(t) + defer cleanup() + + cfg := mysql.Config{ + User: *mysqlUser, + Passwd: *mysqlPass, + DBName: *mysqlDB, + AllowNativePasswords: true, + Addr: "127.0.0.1:3306", + Net: "tcp", + } + proxyConnTest(t, + []string{"--token", tok.AccessToken, *mysqlConnName}, + "mysql", cfg.FormatDSN()) +} + +func TestMySQLAuthWithCredentialsFile(t *testing.T) { + if testing.Short() { + t.Skip("skipping MySQL integration tests") + } + requireMySQLVars(t) + _, path, cleanup := removeAuthEnvVar(t) + defer cleanup() + cfg := mysql.Config{ User: *mysqlUser, Passwd: *mysqlPass, - DBName: *mysqlDb, + DBName: *mysqlDB, AllowNativePasswords: true, Addr: "127.0.0.1:3306", Net: "tcp", } - proxyConnTest(t, *mysqlConnName, "mysql", cfg.FormatDSN(), mysqlPort, "") + proxyConnTest(t, + []string{"--credentials-file", path, *mysqlConnName}, + "mysql", cfg.FormatDSN()) } diff --git a/testsV2/postgres_test.go b/testsV2/postgres_test.go index 3bb86166a..f28f17b5d 100644 --- a/testsV2/postgres_test.go +++ b/testsV2/postgres_test.go @@ -29,11 +29,9 @@ var ( postgresConnName = flag.String("postgres_conn_name", os.Getenv("POSTGRES_CONNECTION_NAME"), "Cloud SQL Postgres instance connection name, in the form of 'project:region:instance'.") postgresUser = flag.String("postgres_user", os.Getenv("POSTGRES_USER"), "Name of database user.") postgresPass = flag.String("postgres_pass", os.Getenv("POSTGRES_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") - postgresDb = flag.String("postgres_db", os.Getenv("POSTGRES_DB"), "Name of the database to connect to.") + postgresDB = flag.String("postgres_db", os.Getenv("POSTGRES_DB"), "Name of the database to connect to.") postgresIAMUser = flag.String("postgres_user_iam", os.Getenv("POSTGRES_USER_IAM"), "Name of database user configured with IAM DB Authentication.") - - postgresPort = 5432 ) func requirePostgresVars(t *testing.T) { @@ -44,17 +42,47 @@ func requirePostgresVars(t *testing.T) { t.Fatal("'postgres_user' not set") case *postgresPass: t.Fatal("'postgres_pass' not set") - case *postgresDb: + case *postgresDB: t.Fatal("'postgres_db' not set") } } -func TestPostgresTcp(t *testing.T) { +func TestPostgresTCP(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + requirePostgresVars(t) + + dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", *postgresUser, *postgresPass, *postgresDB) + proxyConnTest(t, []string{*postgresConnName}, "postgres", dsn) +} + +func TestPostgresAuthWithToken(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + requirePostgresVars(t) + tok, _, cleanup := removeAuthEnvVar(t) + defer cleanup() + + dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", + *postgresUser, *postgresPass, *postgresDB) + proxyConnTest(t, + []string{"--token", tok.AccessToken, *postgresConnName}, + "postgres", dsn) +} + +func TestPostgresAuthWithCredentialsFile(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") } requirePostgresVars(t) + _, path, cleanup := removeAuthEnvVar(t) + defer cleanup() - dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", *postgresUser, *postgresPass, *postgresDb) - proxyConnTest(t, *postgresConnName, "postgres", dsn, postgresPort, "") + dsn := fmt.Sprintf("user=%s password=%s database=%s sslmode=disable", + *postgresUser, *postgresPass, *postgresDB) + proxyConnTest(t, + []string{"--credentials-file", path, *postgresConnName}, + "postgres", dsn) } diff --git a/testsV2/sqlserver_test.go b/testsV2/sqlserver_test.go index 8ab9875ae..3ba683391 100644 --- a/testsV2/sqlserver_test.go +++ b/testsV2/sqlserver_test.go @@ -28,12 +28,10 @@ var ( sqlserverConnName = flag.String("sqlserver_conn_name", os.Getenv("SQLSERVER_CONNECTION_NAME"), "Cloud SQL SqlServer instance connection name, in the form of 'project:region:instance'.") sqlserverUser = flag.String("sqlserver_user", os.Getenv("SQLSERVER_USER"), "Name of database user.") sqlserverPass = flag.String("sqlserver_pass", os.Getenv("SQLSERVER_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") - sqlserverDb = flag.String("sqlserver_db", os.Getenv("SQLSERVER_DB"), "Name of the database to connect to.") - - sqlserverPort = 1433 + sqlserverDB = flag.String("sqlserver_db", os.Getenv("SQLSERVER_DB"), "Name of the database to connect to.") ) -func requireSqlserverVars(t *testing.T) { +func requireSQLServerVars(t *testing.T) { switch "" { case *sqlserverConnName: t.Fatal("'sqlserver_conn_name' not set") @@ -41,17 +39,48 @@ func requireSqlserverVars(t *testing.T) { t.Fatal("'sqlserver_user' not set") case *sqlserverPass: t.Fatal("'sqlserver_pass' not set") - case *sqlserverDb: + case *sqlserverDB: t.Fatal("'sqlserver_db' not set") } } -func TestSqlServerTcp(t *testing.T) { +func TestSQLServerTCP(t *testing.T) { + if testing.Short() { + t.Skip("skipping SQL Server integration tests") + } + requireSQLServerVars(t) + + dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", + *sqlserverUser, *sqlserverPass, *sqlserverDB) + proxyConnTest(t, []string{*sqlserverConnName}, "sqlserver", dsn) +} + +func TestSQLServerAuthWithToken(t *testing.T) { + if testing.Short() { + t.Skip("skipping SQL Server integration tests") + } + requireSQLServerVars(t) + tok, _, cleanup := removeAuthEnvVar(t) + defer cleanup() + + dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", + *sqlserverUser, *sqlserverPass, *sqlserverDB) + proxyConnTest(t, + []string{"--token", tok.AccessToken, *sqlserverConnName}, + "sqlserver", dsn) +} + +func TestSQLServerAuthWithCredentialsFile(t *testing.T) { if testing.Short() { t.Skip("skipping SQL Server integration tests") } - requireSqlserverVars(t) + requireSQLServerVars(t) + _, path, cleanup := removeAuthEnvVar(t) + defer cleanup() - dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", *sqlserverUser, *sqlserverPass, *sqlserverDb) - proxyConnTest(t, *sqlserverConnName, "sqlserver", dsn, sqlserverPort, "") + dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", + *sqlserverUser, *sqlserverPass, *sqlserverDB) + proxyConnTest(t, + []string{"--credentials-file", path, *sqlserverConnName}, + "sqlserver", dsn) }